From a7c3b3cda854535b9d657e366b5c07fae8e02894 Mon Sep 17 00:00:00 2001
From: Siyuan Feng
Date: Fri, 2 Jul 2021 02:46:49 +0800
Subject: [PATCH] [TIR][TVMScript] specialize (#8354)

---
 include/tvm/tir/analysis.h                   |   2 +-
 include/tvm/tir/buffer.h                     |   1 +
 include/tvm/tir/function.h                   |  38 +++
 python/tvm/tir/function.py                   |  55 ++-
 src/tir/ir/specialize.cc                     | 337 +++++++++++++++++++
 tests/python/unittest/test_tir_specialize.py | 199 +++++++++++
 6 files changed, 630 insertions(+), 2 deletions(-)
 create mode 100644 src/tir/ir/specialize.cc
 create mode 100644 tests/python/unittest/test_tir_specialize.py

diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index 262ac688f2e0..63d6fa375c83 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -96,7 +96,7 @@ TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);
 TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);
 
 /*!
- * \brief Whether e expression used any var in variable set..
+ * \brief Whether the expression uses any var in the variable set.
  * \param expr The expression to be checked.
  * \param vset_contains The check function to see if var is in the vset.
  * \return Whether e uses vset.
diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h
index a01d69b372d2..017f4f7052b1 100644
--- a/include/tvm/tir/buffer.h
+++ b/include/tvm/tir/buffer.h
@@ -183,6 +183,7 @@ class Buffer : public ObjectRef {
   TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value) const;
 
   TVM_DEFINE_OBJECT_REF_METHODS(Buffer, ObjectRef, BufferNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferNode);
 };
 
 /*!
diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index 97ee7f7211d4..25ed2f9ae8d1 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -187,6 +187,44 @@ class LinkedParam : public ObjectRef {
   TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
 };
 
+/*!
+ * \brief Specialize parameters of PrimFunc.
+ * \param func The PrimFunc to be specialized.
+ * \param param_map The mapping from function params to the instance.
+ * \return The new function with parameters specialized.
+ * \note We can define a Meta TIR function with symbolic shape:
+ *
+ * \code
+ *  @tvm.script.tir
+ *  def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None:
+ *      A = tir.match_buffer(a, (m, n), "float32")
+ *      B = tir.match_buffer(b, (m, n), "float32")
+ *
+ *      with tir.block([m, n], "") as [vi, vj]:
+ *          B[vi, vj] = A[vi, vj]
+ * \endcode
+ *
+ * Then we can make it specialized with given shapes or buffers.
+ *
+ * \code
+ *  a, _, m, n = mem_copy.params
+ *  func = mem_copy.specialize({a: tir.decl_buffer((16, 16))})
+ *  # or
+ *  func = mem_copy.specialize({n: 16, m: 16})
+ * \endcode
+ *
+ * \code
+ *  @tvm.script.tir
+ *  def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None:
+ *      A = tir.match_buffer(a, (16, 16), "float32")
+ *      B = tir.match_buffer(b, (16, 16), "float32")
+ *
+ *      with tir.block([16, 16], "") as [vi, vj]:
+ *          B[vi, vj] = A[vi, vj]
+ * \endcode
+ */
+PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map);
+
 /*!
  * \brief PrimFunc specific attribute names.
  *
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index 79d18d8970b5..b1081d436150 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -16,12 +16,14 @@
 # under the License.
"""Function data types.""" +from typing import Mapping, Union + import tvm._ffi import tvm.runtime from tvm.runtime import Object from tvm.ir import BaseFunc from .buffer import Buffer -from .expr import Var +from .expr import Var, PrimExpr from . import _ffi_api @@ -85,3 +87,54 @@ def with_body(self, new_body, span=None): The created new function. """ return PrimFunc(self.params, new_body, self.ret_type, self.buffer_map, self.attrs, span) + + def specialize(self, param_map: Mapping[Var, Union[PrimExpr, Buffer]]): + """Specialize parameters of PrimFunc + + Parameters + ---------- + + param_map : Mapping[Var, Union[PrimExpr, Buffer]] + The mapping from function params to the instance + + Examples + -------- + We can define a Meta TIR function with symbolic shape: + + .. code-block:: python + + @tvm.script.tir + def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None: + A = tir.match_buffer(a, (m, n), "float32") + B = tir.match_buffer(b, (m, n), "float32") + + with tir.block([m, n], "") as [vi, vj]: + B[vi, vj] = A[vi, vj] + + Then we can make it specialized with given shapes or buffers. + + .. code-block:: python + + a, _, m, n = mem_copy.params + func = mem_copy.specialize({a: tir.decl_buffer((16, 16))}) + # or + func = mem_copy.specialize({n: 16, m: 16}) + + The specialized function: + + .. code-block:: python + + @tvm.script.tir + def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + B = tir.match_buffer(b, (16, 16), "float32") + + with tir.block([16, 16], "") as [vi, vj]: + B[vi, vj] = A[vi, vj] + + Returns + ------- + func : PrimFunc + The new function with parameter specialized + """ + return _ffi_api.Specialize(self, param_map) diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc new file mode 100644 index 000000000000..aa5f271c20c2 --- /dev/null +++ b/src/tir/ir/specialize.cc @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tir/ir/specialize.cc + * \brief Specialize parameters of PrimFunc. + */ +#include +#include +#include +#include + +#include + +#include "functor_common.h" + +namespace tvm { +namespace tir { + +using VarMap = std::unordered_map; + +/**************** Helper functions ****************/ + +/*! \brief Helper function to check whether the given var is in function parameter list. */ +inline bool IsParam(const PrimFunc& func, const Var& param) { + return std::any_of(func->params.begin(), func->params.end(), + [&](const Var& var) { return var.same_as(param); }); +} + +/**************** Specializer ****************/ + +/*! 
+ * \brief Mutator to specialize function and remove const parameters.
+ */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(const VarMap& var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating buffer map
+    Map<Var, Buffer> buffer_map;
+    bool buffer_map_updated = false;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        buffer_map_updated = true;
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parameters
+    Array<Var> params;
+    bool param_updated = false;
+    for (const auto& var : f->params) {
+      // Remove parameters which have been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      } else {
+        param_updated = true;
+      }
+    }
+
+    // Updating function body
+    Stmt body = specializer(f->body);
+
+    if (param_updated || buffer_map_updated || !f->body.same_as(body)) {
+      PrimFuncNode* f_ptr = f.CopyOnWrite();
+      f_ptr->params = std::move(params);
+      f_ptr->buffer_map = std::move(buffer_map);
+      f_ptr->body = std::move(body);
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    // Step.0. Define buffer mappings for buffers allocated inside the block
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+
+    // Step.1. Recursively visit block body
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BlockNode>();
+    ICHECK(op != nullptr);
+
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
+        writes.same_as(op->writes)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+    ICHECK(op != nullptr);
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<Stmt>(op);
+    } else {
+      auto n = CopyOnWrite(op);
+      n->buffer = it->second;
+      return Stmt(n);
+    }
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+    ICHECK(op != nullptr);
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      auto n = make_object<BufferLoadNode>(*op);
+      n->buffer = it->second;
+      return PrimExpr(n);
+    }
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_map_.find(GetRef<Var>(op));
+    if (it == var_map_.end()) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return it->second;
+    }
+  }
+
+ private:
+  Buffer MutateBuffer(const Buffer& buffer) const {
+    Array<PrimExpr> shape =
+        MutateArray(buffer->shape, [this](const PrimExpr& e) { return Substitute(e, var_map_); });
+    Array<PrimExpr> strides =
+        MutateArray(buffer->strides, [this](const PrimExpr& e) { return Substitute(e, var_map_); });
+
+    PrimExpr elem_offset = Substitute(buffer->elem_offset, var_map_);
+
+    if (buffer->elem_offset.same_as(elem_offset) && buffer->shape.same_as(shape) &&
+        buffer->strides.same_as(strides)) {
+      return buffer;
+    } else {
+      auto n = make_object<BufferNode>(*buffer.get());
+      n->elem_offset = std::move(elem_offset);
+      n->shape = std::move(shape);
+      n->strides = std::move(strides);
+      return Buffer(n);
+    }
+  }
+
+  Range MutateRange(const Range& range) {
+    PrimExpr min = this->VisitExpr(range->min);
+    PrimExpr extent = this->VisitExpr(range->extent);
+    if (min.same_as(range->min) && extent.same_as(range->extent)) {
+      return range;
+    } else {
+      return Range::FromMinExtent(std::move(min), std::move(extent));
+    }
+  }
+
+  Buffer MutateAllocBuffer(const Buffer& alloc_buf) {
+    Buffer buf = MutateBuffer(alloc_buf);
+    if (buf.same_as(alloc_buf)) {
+      return alloc_buf;
+    } else {
+      ICHECK(buffer_map_.find(alloc_buf) == buffer_map_.end());
+      buffer_map_[alloc_buf] = buf;
+      return buf;
+    }
+  }
+
+  BufferRegion MutateBufferRegion(const BufferRegion& buffer_region) {
+    auto it = buffer_map_.find(buffer_region->buffer);
+    Array<Range> region =
+        MutateArray(buffer_region->region,
+                    std::bind(&PrimFuncSpecializer::MutateRange, this, std::placeholders::_1));
+    if (it == buffer_map_.end() && region.same_as(buffer_region->region)) {
+      return buffer_region;
+    } else {
+      // Keep the original buffer if it has no specialized counterpart.
+      return BufferRegion(it == buffer_map_.end() ? buffer_region->buffer : it->second,
+                          std::move(region));
+    }
+  }
+
+ private:
+  /*! \brief The vars to be substituted and their values */
+  const VarMap& var_map_;
+  /*! \brief Map from old buffer to mutated buffer */
+  std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
+};
+
+/*!
+ * \brief Update the specialize var map with buffer matching.
+ * \param func The function to be specialized.
+ * \param param The given function parameter
+ * \param specific_buf The matching buffer.
+ * \param var_map The var mapping to be updated.
+ * \note This function will match the target buffer's shape, strides and elem_offset.
+ *   For example, we define a buffer in PrimFunc:
+ *   A = tir.match_buffer(a, [m, n])
+ *
+ *   Then we match it with a buffer B = tir.decl_buffer((8, 16))
+ *
+ *   It means we have two var mappings here: m = 8 and n = 16
+ *
+ *   If the buffer signature is not a Var, the mapping will fail.
+ *   e.g. A = tir.match_buffer(a, [m * 2, n + 1])
+ */
+void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer& specific_buf,
+                            VarMap* var_map) {
+  // preliminaries
+  tir::ExprDeepEqual equal;
+
+  auto it = func->buffer_map.find(param);
+  CHECK(it != func->buffer_map.end())
+      << "ValueError: specialize expects param to be in PrimFunc's buffer_map";
+  const Buffer& buf_to_specialize = (*it).second;
+
+  // build var mapping using specific_buf's parameters
+  auto build_var_mapping = [&](const PrimExpr& new_expr, const PrimExpr& old_expr) {
+    if (!equal(new_expr, old_expr)) {
+      CHECK(old_expr->IsInstance<VarNode>())
+          << "TypeError: The signature of the target buffer expected an independent Var, but got "
+          << old_expr << ".";
+      const Var& var = Downcast<Var>(old_expr);
+      auto it = var_map->find(var);
+      if (it != var_map->end()) {
+        CHECK(equal(it->second, new_expr))
+            << "ValueError: The assigned value of var " << var << " mismatched. " << it->second
+            << " vs. " << new_expr << ".";
+      } else {
+        (*var_map)[var] = new_expr;
+      }
+    }
+  };
+
+  // Check buffer dimensions
+  CHECK(specific_buf->shape.size() == buf_to_specialize->shape.size())
+      << "ValueError: The buffer dimensions mismatched: " << buf_to_specialize->shape.size()
+      << " vs. " << specific_buf->shape.size() << ".";
+
+  CHECK(specific_buf->strides.size() == buf_to_specialize->strides.size())
+      << "ValueError: The buffer strides dimensions mismatched: "
+      << buf_to_specialize->strides.size() << " vs. " << specific_buf->strides.size() << ".";
" << specific_buf->strides.size() << "."; + + // Updating var mapping using specific_expr + for (size_t i = 0; i < specific_buf->shape.size(); ++i) { + build_var_mapping(specific_buf->shape[i], buf_to_specialize->shape[i]); + } + for (size_t i = 0; i < specific_buf->strides.size(); ++i) { + build_var_mapping(specific_buf->strides[i], buf_to_specialize->strides[i]); + } + build_var_mapping(specific_buf->elem_offset, buf_to_specialize->elem_offset); + + // Check data_alignment and offset_factor. + // These two signatures are int, so we do not need map them. + CHECK_EQ(specific_buf->data_alignment, buf_to_specialize->data_alignment) + << "ValueError: The buffer data_alignment mismatched" << buf_to_specialize->data_alignment + << " vs. " << specific_buf->data_alignment << "."; + + CHECK_EQ(specific_buf->offset_factor, buf_to_specialize->offset_factor) + << "ValueError: The buffer offset_factor mismatched" << buf_to_specialize->offset_factor + << " vs. " << specific_buf->offset_factor << "."; +} + +/*! + * \brief Update Specialize var map with parameter value. + * \param func The function to be specialized. + * \param param The given function parameter + * \param specific_expr The parameter value. + * \param var_map The var mapping to be updated. + */ +void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const PrimExpr& specific_expr, + VarMap* var_map) { + // check param is in PrimFunc's parameters + CHECK(IsParam(func, param)) << "ValueError: Specialize expects param to be in PrimFunc's params"; + // specialize a param not in buffer_map + CHECK_EQ(func->buffer_map.count(param), 0) + << "ValueError: Specialize expects param to not be in PrimFunc's buffer_map"; + // build var mapping using specific_expr + (*var_map)[param] = specific_expr; +} + +/**************** Implementation ****************/ + +PrimFunc Specialize(PrimFunc func, const Map& param_map) { + VarMap var_map; + for (const auto& kv : param_map) { + const Var& param = kv.first; + const ObjectRef& instance = kv.second; + if (instance->IsInstance()) { + UpdateSpecializeVarMap(func, param, Downcast(instance), &var_map); + } else if (instance->IsInstance()) { + UpdateSpecializeVarMap(func, param, Downcast(instance), &var_map); + } else { + LOG(FATAL) << "TypeError: specialize expected instance to be Buffer or PrimExpr, but got " + << instance->GetTypeKey(); + } + } + return PrimFuncSpecializer::Specialize(func, std::move(var_map)); +} + +/**************** FFI ****************/ + +TVM_REGISTER_GLOBAL("tir.Specialize").set_body_typed(Specialize); + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_specialize.py b/tests/python/unittest/test_tir_specialize.py new file mode 100644 index 000000000000..2e9f1110732a --- /dev/null +++ b/tests/python/unittest/test_tir_specialize.py @@ -0,0 +1,199 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring, missing-module-docstring
+
+import tvm
+from tvm import tir
+from tvm.script import ty
+
+
+@tvm.script.tir
+def matmul(a: ty.handle, b: ty.handle, c: ty.handle, n: ty.int32) -> None:
+    m = tir.var("int32")
+    A = tir.match_buffer(a, [m, n])
+    B = tir.match_buffer(b, [m, n])
+    C = tir.match_buffer(c, [m, m])
+
+    with tir.block([m, m, tir.reduce_axis(0, n)], "update") as [vi, vj, vk]:
+        with tir.init():
+            C[vi, vj] = 0.0
+        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@tvm.script.tir
+def matmul_128(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, [128, 128])
+    B = tir.match_buffer(b, [128, 128])
+    C = tir.match_buffer(c, [128, 128])
+
+    with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
+        with tir.init():
+            C[vi, vj] = 0.0
+        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@tvm.script.tir
+def matmul_m_128(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+    m = tir.var("int32")
+    A = tir.match_buffer(a, [m, 128])
+    B = tir.match_buffer(b, [m, 128])
+    C = tir.match_buffer(c, [m, m])
+
+    with tir.block([m, m, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
+        with tir.init():
+            C[vi, vj] = 0.0
+        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@tvm.script.tir
+def matmul_m_8x(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+    x = tir.var("int32")
+    m = tir.var("int32")
+    A = tir.match_buffer(a, [m, x * 8])
+    B = tir.match_buffer(b, [m, x * 8])
+    C = tir.match_buffer(c, [m, m])
+
+    with tir.block([m, m, tir.reduce_axis(0, x * 8)], "update") as [vi, vj, vk]:
+        with tir.init():
+            C[vi, vj] = 0.0
+        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@tvm.script.tir
+def element_wise(a: ty.handle, c: ty.handle) -> None:
+    m = tir.var("int32")
+    n = tir.var("int32")
+    A = tir.match_buffer(a, (m, n), "float32")
+    C = tir.match_buffer(c, (m, n), "float32")
+
+    B = tir.alloc_buffer((m, n), "float32")
+
+    with tir.block([m, n], "B") as [vi, vj]:
+        B[vi, vj] = A[vi, vj] * 2.0
+
+    with tir.block([m, n], "C") as [vi, vj]:
+        C[vi, vj] = B[vi, vj] + 1.0
+
+
+@tvm.script.tir
+def element_wise_128_64(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 64), "float32")
+    C = tir.match_buffer(c, (128, 64), "float32")
+    B = tir.alloc_buffer((128, 64), "float32")
+
+    with tir.block([128, 64], "B") as [vi, vj]:
+        B[vi, vj] = A[vi, vj] * 2.0
+
+    with tir.block([128, 64], "C") as [vi, vj]:
+        C[vi, vj] = B[vi, vj] + 1.0
+
+
+@tvm.script.tir
+def element_wise_128_n(a: ty.handle, c: ty.handle) -> None:
+    n = tir.var("int32")
+    A = tir.match_buffer(a, (128, n), "float32")
+    C = tir.match_buffer(c, (128, n), "float32")
+    B = tir.alloc_buffer((128, n), "float32")
+
+    with tir.block([128, n], "B") as [vi, vj]:
+        B[vi, vj] = A[vi, vj] * 2.0
+
+    with tir.block([128, n], "C") as [vi, vj]:
+        C[vi, vj] = B[vi, vj] + 1.0
+
+
+@tvm.script.tir
+def mem_copy(
+    a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32, p: ty.int32, q: ty.int32
+) -> None:
+    A = tir.match_buffer(a, (m, n), "float32", strides=[p, 1], elem_offset=q)
+    B = tir.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=q)
+
+    with tir.block([m, n], "") as [vi, vj]:
+        B[vi, vj] = A[vi, vj]
+
+
+@tvm.script.tir
+def mem_copy_16_16_8_4(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (16, 16), "float32", strides=[8, 1], elem_offset=4)
+    B = tir.match_buffer(b, (16, 16), "float32", strides=[8, 1], elem_offset=4)
+
+    with tir.block([16, 16], "") as [vi, vj]:
+        B[vi, vj] = A[vi, vj]
+
+
+@tvm.script.tir
+def mem_copy_m_n_p_n(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32, p: ty.int32) -> None:
+    A = tir.match_buffer(a, (m, n), "float32", strides=[p, 1], elem_offset=n)
+    B = tir.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=n)
+
+    with tir.block([m, n], "") as [vi, vj]:
+        B[vi, vj] = A[vi, vj]
+
+
+def test_specialize_nothing():
+    func = matmul.specialize({})
+    assert func.same_as(matmul)  # specializing nothing returns the same object
+
+
+def test_specialize_matmul():
+    a, _, _, n = matmul.params
+    # fully specialized
+    func = matmul.specialize({a: tir.decl_buffer((128, 128))})
+    tvm.ir.assert_structural_equal(func, matmul_128)
+    # partially specialized
+    func = matmul.specialize({n: 128})
+    tvm.ir.assert_structural_equal(func, matmul_m_128)
+    # symbolic specialized
+    func = matmul.specialize({n: tir.Var("x", "int32") * 8})
+    tvm.ir.assert_structural_equal(func, matmul_m_8x)
+
+
+def test_specialize_elemwise():
+    a, c = element_wise.params
+    C = element_wise.buffer_map[c]
+    # fully specialized
+    func = element_wise.specialize({a: tir.decl_buffer((128, 64))})
+    tvm.ir.assert_structural_equal(func, element_wise_128_64)
+    # partially specialized
+    func = element_wise.specialize({c: tir.decl_buffer((128, C.shape[1]))})
+    tvm.ir.assert_structural_equal(func, element_wise_128_n)
+
+
+def test_specialize_mem_copy():
+    a, _, m, n, p, q = mem_copy.params
+    # fully specialized
+    func = mem_copy.specialize({a: tir.decl_buffer((16, 16), strides=[8, 1], elem_offset=4)})
+    tvm.ir.assert_structural_equal(func, mem_copy_16_16_8_4)
+    func = mem_copy.specialize({n: 16, m: 16, p: 8, q: 4})
+    tvm.ir.assert_structural_equal(func, mem_copy_16_16_8_4)
+    # partially specialized
+    func = mem_copy.specialize({q: n})
+    tvm.ir.assert_structural_equal(func, mem_copy_m_n_p_n)
+
+
+def test_specialize_recursive_load():
+    # TODO(Siyuan): add recursive Load testcase, e.g. A[C[i]]
+    pass
+
+
+if __name__ == "__main__":
+    test_specialize_nothing()
+    test_specialize_matmul()
+    test_specialize_elemwise()
+    test_specialize_mem_copy()
+    test_specialize_recursive_load()
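
Usage sketch (not part of the patch): the snippet below exercises the new PrimFunc.specialize API end to end, mirroring the docstring example above. It is a minimal sketch assuming a TVM build that includes this change and the TVMScript dialect of this era (tvm.script.tir, ty.handle); later releases renamed these entry points.

    import tvm
    from tvm import tir
    from tvm.script import ty

    @tvm.script.tir
    def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None:
        A = tir.match_buffer(a, (m, n), "float32")
        B = tir.match_buffer(b, (m, n), "float32")

        with tir.block([m, n], "") as [vi, vj]:
            B[vi, vj] = A[vi, vj]

    a, _, m, n = mem_copy.params

    # Bind the shape vars directly; m and n drop out of the parameter list.
    print(mem_copy.specialize({m: 16, n: 16}))

    # Or match a concrete buffer against param `a`; the specializer infers
    # m = 16 and n = 16 by matching the buffer signature against A's shape.
    print(mem_copy.specialize({a: tir.decl_buffer((16, 16))}))

Both calls return a new PrimFunc structurally equal to mem_copy_16_16 from the docstring and leave mem_copy itself untouched; specializing with an empty map returns the original object, as test_specialize_nothing checks.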