diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h new file mode 100644 index 000000000000..00f8cf6ee9f0 --- /dev/null +++ b/include/tvm/arith/iter_affine_map.h @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/arith/iter_affine_map.h + * \brief Iterator quasi-affine mapping patterns. + * + * This file defines a collection of mapping patterns + * that map a collection of independent iterators to another + * collection of independent iterators. + * + * There are two main kinds of mapping patterns: + * + * - Fuse: fuse a collection of iterators into a single one + * + * domain(x0) = [0, 4), domain(x1) = [0, 3), domain(x2) = [0, 2) + * fuse(x0, x1, x2): y = x2 * 12 + x1 * 4 + x0 + * domain(y) = [0, 24) + * + * - Split: split an iterator into multiple ones + * + * domain(x) = [0, 24) + * split(x, 3, 12): [y0, y1, y2] = [x % 3, (x % 12) / 3, x / 12] + * domain(y0) = [0, 3), domain(y1) = [0, 4), domain(y2) = [0, 2) + * + * We use the name "(quasi)affine" to be consistent with + * the terminology used in the polyhedral compilation. 
+ * Notably, fuse is an affine transformation, + * while split corresponds to additional floordiv/mod operations + * that can appear in quasi-affine transformations. + */ +#ifndef TVM_ARITH_ITER_AFFINE_MAP_H_ +#define TVM_ARITH_ITER_AFFINE_MAP_H_ + +#include + +namespace tvm { +namespace arith { + +/*! + * \brief Base class of all iter map expressions. + * + * An IterMapExpr is a special expression to store + * the result of IterMapDetection. + * It should not appear in a legal TIR PrimFunc. + */ +class IterMapExprNode : public PrimExprNode { + public: + // overrides + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "arith.IterMapExpr"; + static constexpr const uint32_t _type_child_slots = 3; + TVM_DECLARE_BASE_OBJECT_INFO(IterMapExprNode, PrimExprNode); +}; + +/*! + * \brief Managed reference to IterMapExprNode. + * \sa IterMapExprNode + */ +class IterMapExpr : public PrimExpr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(IterMapExpr, PrimExpr, IterMapExprNode); +}; + +/*! + * \brief Mark the source as an iterator in [0, extent). + * + * IterMark is used to mark source expression as a valid + * iterator to make future analysis easy. + */ +class IterMarkNode : public Object { + public: + /*! + * \brief The source expression, can either be + * a IterSumExpr or a Var. + */ + PrimExpr source; + /*! + * \brief The extent of the iteration. 
+ */ + PrimExpr extent; + + // overrides + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("source", &source); + v->Visit("extent", &extent); + } + + bool SEqualReduce(const IterMarkNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal(source, other->source) && equal(extent, other->extent); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce->MarkGraphNode(); + hash_reduce(source); + hash_reduce(extent); + } + + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + static constexpr const char* _type_key = "arith.IterMark"; + TVM_DECLARE_FINAL_OBJECT_INFO(IterMarkNode, Object); +}; + +/*! + * \brief Managed reference to IterMarkNode. + * \sa IterMarkNode + */ +class IterMark : public ObjectRef { + public: + /*! + * \brief constructor. + * \param source The source expression. + * \param extent The extent of the iterator. + */ + TVM_DLL IterMark(PrimExpr source, PrimExpr extent); + + TVM_DEFINE_OBJECT_REF_METHODS(IterMark, ObjectRef, IterMarkNode); +}; + +/*! + * \brief Split of an iterator. + * + * result = floormod(floordiv(source, lower_factor), extent) * scale + */ +class IterSplitExprNode : public IterMapExprNode { + public: + /*! \brief The source marked iterator. */ + IterMark source; + /*! \brief The lower factor to split the source. */ + PrimExpr lower_factor; + /*! \brief The extent of the split. */ + PrimExpr extent; + /*! \brief Additional scale. 
*/ + PrimExpr scale; + + // overrides + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("source", &source); + v->Visit("lower_factor", &lower_factor); + v->Visit("extent", &extent); + v->Visit("scale", &scale); + } + + bool SEqualReduce(const IterSplitExprNode* other, SEqualReducer equal) const { + return equal(source, other->source) && equal(lower_factor, other->lower_factor) && + equal(extent, other->extent) && equal(scale, other->scale); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(source); + hash_reduce(lower_factor); + hash_reduce(extent); + hash_reduce(scale); + } + + static constexpr const char* _type_key = "arith.IterSplitExpr"; + TVM_DECLARE_FINAL_OBJECT_INFO(IterSplitExprNode, IterMapExprNode); +}; + +/*! + * \brief Managed reference to IterSplitExprNode. + * \sa IterSplitExprNode + */ +class IterSplitExpr : public IterMapExpr { + public: + /*! + * \brief constructor from just source. + * \param source The source expression. + */ + TVM_DLL explicit IterSplitExpr(IterMark source); + /*! + * \brief constructor + * \param source The source expression. + * \param lower_factor The lower factor to split the source. + * \param extent The extent of the split. + * \param scale The additional scaling factor. + */ + TVM_DLL explicit IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr extent, + PrimExpr scale); + + TVM_DEFINE_OBJECT_REF_METHODS(IterSplitExpr, IterMapExpr, IterSplitExprNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(IterSplitExprNode); +}; + +/*! + * \brief Fuse multiple iterators by summing them with scaling. + * + * result = sum(args) + base + */ +class IterSumExprNode : public IterMapExprNode { + public: + /*! \brief The args to the sum. */ + Array args; + /*! \brief The base offset. 
+ */ + PrimExpr base; + + // overrides + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("args", &args); + v->Visit("base", &base); + } + + bool SEqualReduce(const IterSumExprNode* other, SEqualReducer equal) const { + return equal(args, other->args) && equal(base, other->base); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(args); + hash_reduce(base); + } + + static constexpr const char* _type_key = "arith.IterSumExpr"; + TVM_DECLARE_FINAL_OBJECT_INFO(IterSumExprNode, IterMapExprNode); +}; + +/*! + * \brief Managed reference to IterSumExprNode. + * \sa IterSumExprNode + */ +class IterSumExpr : public IterMapExpr { + public: + /*! + * \brief constructor. + * \param args The args to the sum. + * \param base The base offset. + */ + TVM_DLL IterSumExpr(Array args, PrimExpr base); + + TVM_DEFINE_OBJECT_REF_METHODS(IterSumExpr, IterMapExpr, IterSumExprNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(IterSumExprNode); +}; + +/*! + * \brief Detect if indices can be written as + * + * [y_0 + c_0, y_1 + c_1, ..., y_n + c_n] + * + * Here y = some-quasi-affine-iter-map(input_iters) + * and c are symbolic constants. + * + * We also require that y_i and y_j be independent for i != j. + * + * For returned value rv, the following is always true: + * - rv[i]->args.size() <= 1: only one iterator per element. + * + * \param indices The indices to detect pattern for. + * \param input_iters Map from variable to iterator's range. + * \param analyzer Analyzer used to get context information. + * + * \return The detected pattern if a match exists, + * otherwise return an empty array. 
+ */ +Array DetectIterMap(const Array& indices, const Map& input_iters, + arith::Analyzer* analyzer); + +} // namespace arith +} // namespace tvm +#endif // TVM_ARITH_ITER_AFFINE_MAP_H_ diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py index e5af52938f5c..77ec869a171e 100644 --- a/python/tvm/arith/__init__.py +++ b/python/tvm/arith/__init__.py @@ -21,3 +21,5 @@ from .bound import deduce_bound from .pattern import detect_linear_equation, detect_clip_bound from .int_solver import solve_linear_equations, solve_linear_inequalities +from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr +from .iter_affine_map import detect_iter_map diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py new file mode 100644 index 000000000000..123d9b85480a --- /dev/null +++ b/python/tvm/arith/iter_affine_map.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Iterator (quasi)affine mapping patterns.""" +import tvm._ffi +from tvm.runtime import Object +from tvm.ir import PrimExpr +from . 
import _ffi_api + + +class IterMapExpr(PrimExpr): + """Base class of all IterMap expressions.""" + + +@tvm._ffi.register_object("arith.IterMark") +class IterMark(Object): + """Mark the source as an iterator in [0, extent). + + Parameters + ---------- + source : PrimExpr + The source expression. + + extent : PrimExpr + The extent of the iterator. + """ + + def __init__(self, source, extent): + self.__init_handle_by_constructor__(_ffi_api.IterMark, source, extent) + + +@tvm._ffi.register_object("arith.IterSplitExpr") +class IterSplitExpr(IterMapExpr): + """Split of an iterator. + + result = floormod(floordiv(source, lower_factor), extent) * scale + + Parameters + ---------- + source : IterMark + The source marked iterator. + + lower_factor : PrimExpr + The lower factor to split the domain. + + extent : PrimExpr + The extent of the split. + + scale : PrimExpr + Additional scale to the split. + """ + + def __init__(self, source, lower_factor, extent, scale): + self.__init_handle_by_constructor__( + _ffi_api.IterSplitExpr, source, lower_factor, extent, scale + ) + + +@tvm._ffi.register_object("arith.IterSumExpr") +class IterSumExpr(IterMapExpr): + """Fuse multiple iterators by summing them with scaling. + + result = sum(args) + base + + Parameters + ---------- + args : List[IterSplitExpr] + The input to the sum expression. + + base : PrimExpr + The base offset. + """ + + def __init__(self, args, base): + self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base) + + +def detect_iter_map(indices, input_iters): + """Detect if indices can be written as mapped iters from input_iters. + + Parameters + ---------- + indices : List[PrimExpr] + The input indices. + + input_iters : Map[Var, Range] + The domain of each input iterator. + + Returns + ------- + results : List[IterSumExpr] + The iter map matching result. + Empty array if no match can be found. 
+ """ + return _ffi_api.DetectIterMap(indices, input_iters) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc new file mode 100644 index 000000000000..7afa75a7efb0 --- /dev/null +++ b/src/arith/iter_affine_map.cc @@ -0,0 +1,717 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file src/arith/iter_affine_map.cc + */ +#include +#include +#include +#include +#include +#include + +#include "../support/util.h" +#include "const_fold.h" + +namespace tvm { +namespace arith { + +using namespace tir; + +IterMark::IterMark(PrimExpr source, PrimExpr extent) { + auto n = make_object(); + n->source = std::move(source); + n->extent = std::move(extent); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("arith.IterMark").set_body_typed([](PrimExpr source, PrimExpr extent) { + return IterMark(source, extent); +}); + +TVM_REGISTER_NODE_TYPE(IterMarkNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "IterMark(" << op->source << ", extent=" << op->extent << ")"; + }); + +IterSplitExpr::IterSplitExpr(IterMark source) { + auto n = make_object(); + auto one = make_const(source->source->dtype, 1); + n->dtype = source->source->dtype; + n->source = std::move(source); + n->extent = n->source->extent; + n->lower_factor = one; + n->scale = one; + data_ = std::move(n); +} + +IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr extent, + PrimExpr scale) { + auto n = make_object(); + n->dtype = source->source->dtype; + n->source = std::move(source); + n->lower_factor = std::move(lower_factor); + n->extent = std::move(extent); + n->scale = std::move(scale); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("arith.IterSplitExpr") + .set_body_typed([](IterMark source, PrimExpr lower_factor, PrimExpr extent, PrimExpr scale) { + return IterSplitExpr(source, lower_factor, extent, scale); + }); + +TVM_REGISTER_NODE_TYPE(IterSplitExprNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "IterSplit(" << op->source << ", lower_factor=" << op->lower_factor + << ", extent=" << op->extent << ", scale=" << op->scale << ")"; + }); + 
+IterSumExpr::IterSumExpr(Array args, PrimExpr base) { + auto n = make_object(); + n->dtype = base->dtype; + n->args = std::move(args); + n->base = std::move(base); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("arith.IterSumExpr") + .set_body_typed([](Array args, PrimExpr base) { + return IterSumExpr(args, base); + }); + +TVM_REGISTER_NODE_TYPE(IterSumExprNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "IterSum(" << op->args << ", " << op->base << ")"; + }); + +/*! + * \brief Collector that collects + * the outgoing split reference of each IterMark. + * + * These out-going splits can then be used to + * check if the iterators are independent. + */ +class IterMarkSplitCollector { + public: + // mark all IterMarks that are visited. + std::unordered_set visited_; + // each iter mark to its outgoing splits that are referenced. + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> + mark2splits_; + /*! + * \brief Collect all mark2splits recursively from indices. + * \param indices The iterator of interest. 
+ */ + void Collect(const Array& indices) { + for (IterSumExpr sum_expr : indices) { + for (IterSplitExpr split : sum_expr->args) { + this->CollectInternal(split->source); + mark2splits_[split->source].push_back(split); + } + } + } + + void CollectInternal(const IterMark& mark) { + if (visited_.count(mark)) return; + visited_.insert(mark); + if (auto* op = mark->source.as()) { + for (IterSplitExpr split : op->args) { + this->CollectInternal(split->source); + mark2splits_[split->source].push_back(split); + } + } + } +}; + +// Rewriter to rewrite PrimExpr to IterMapExpr +// when possible +class IterMapRewriter : public ExprMutator { + public: + using Parent = ExprMutator; + + explicit IterMapRewriter(Analyzer* analyzer, const Map& input_iters) + : analyzer_(analyzer) { + for (auto kv : input_iters) { + const auto& vrng = kv.second; + if (is_zero(vrng->min)) { + IterMark mark(kv.first, vrng->extent); + var_map_[kv.first] = IterSplitExpr(mark); + input_marks_.push_back(mark); + } else { + IterMark mark(kv.first - vrng->min, vrng->extent); + auto sum_expr = ToIterSumExpr(IterSplitExpr(mark)); + sum_expr.CopyOnWrite()->base = vrng->min; + var_map_[kv.first] = sum_expr; + input_marks_.push_back(mark); + } + } + } + + size_t unresolved_count() const { return unresolved_count_; } + + IterSumExpr Rewrite(PrimExpr expr) { + return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr))); + } + + bool CheckBijective(const Array& indices) { + // This function checks two conditions: + // - C0: Each iter mark should be fully covered by non-overlapping splits. + // - C1: All of the input iterators are used. + // + // Example: given x in [0, 8) y in [0, 6) + // - indices = [x, x+1, y] won't pass because x and x+1 contribute + // two splits that overlaps with each other. + // - indices = [x / 4, x % 4, y] will pass because x / 4 and x % 4 + // contribute two non-overlapping splits that covers x. + // - indices = [x / 4, x % 4] won't pass because y is not used. 
+ // + IterMarkSplitCollector collector; + // We can check that for each iter mark: + // All the splits that refers to the itermark covers its extent. + // The splits do not overlap with each other. + collector.Collect(indices); + for (IterMark mark : collector.visited_) { + if (TryNormalizeSplits(mark, collector.mark2splits_[mark]).size() == 0) return false; + } + // all input marks must be visited + for (auto mark : input_marks_) { + if (collector.visited_.count(mark) == 0) return false; + } + return true; + } + + // override the original mutate function. + PrimExpr VisitExpr(const PrimExpr& input_expr) final { + auto expr = ExprMutator::VisitExpr(input_expr); + if (expr->IsInstance()) { + ++unresolved_count_; + } + return expr; + } + + // Normal mutation without normalization. + PrimExpr DirectMutate(PrimExpr expr) { return ExprMutator::VisitExpr(expr); } + + PrimExpr VisitExpr_(const VarNode* op) final; + PrimExpr VisitExpr_(const AddNode* op) final; + PrimExpr VisitExpr_(const SubNode* op) final; + PrimExpr VisitExpr_(const MulNode* op) final; + PrimExpr VisitExpr_(const FloorDivNode* op) final; + PrimExpr VisitExpr_(const FloorModNode* op) final; + + private: + // temp hash for de-duplication purposes. + struct IterSumHash { + size_t operator()(const IterSumExpr& value) const { + // for now only hash on source index. 
+ size_t hash = value->args.size(); + for (size_t i = 0; i < value->args.size(); ++i) { + hash = support::HashCombine(hash, std::hash()(value->args[i]->source.get())); + } + return hash; + } + }; + + struct IterSumEqual { + bool operator()(const IterSumExpr& lhs, const IterSumExpr& rhs) const { + tir::ExprDeepEqual equal; + if (lhs->args.size() != rhs->args.size()) return false; + if (!equal(lhs->base, rhs->base)) return false; + for (size_t i = 0; i < lhs->args.size(); ++i) { + auto lvalue = lhs->args[i]; + auto rvalue = lhs->args[i]; + if (!lvalue->source.same_as(rvalue->source)) return false; + if (!equal(lvalue->lower_factor, rvalue->lower_factor)) return false; + if (!equal(lvalue->scale, rvalue->scale)) return false; + if (!equal(lvalue->extent, rvalue->extent)) return false; + } + return true; + } + }; + + // Internal analyzer + Analyzer* analyzer_; + // Counter to keep track of unresolved cases. + int unresolved_count_{0}; + // The var map + std::unordered_map var_map_; + // input iter marks + std::vector input_marks_; + // The canonical map for sum + std::unordered_map sum_fuse_map_; + + /*! + * \brief Verify that splits fully covers mark in a non-overlapping fashion. + * If verification passes, return splits from outermost to inner most order. + * If not, return an empty array + * \param mark The iterator of interest. + * \param splits The splits to be verified. + * \return The normalized splits. 
+ */ + Array TryNormalizeSplits(const IterMark& mark, + const std::vector& splits) { + std::vector used(splits.size(), false); + std::vector iters; + PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1); + + for (size_t i = 0; i < splits.size(); ++i) { + size_t j = 0; + for (; j < splits.size(); ++j) { + if (used[j]) continue; + if (!used[j] && CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) break; + } + if (j == splits.size()) { + return Array(); + } + used[j] = true; + iters.push_back(splits[j]); + expected_lower_factor *= splits[j]->extent; + } + if (!CanProveEqual(expected_lower_factor, mark->extent)) return Array(); + return Array(iters.rbegin(), iters.rend()); + } + + /*! + * \brief Normalize expr to an iterator + offset. + * \param expr The input expression. + * \return The Normalized expression. + */ + IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) { + if (expr->args.size() <= 1) return expr; + PrimExpr base = expr->base; + expr.CopyOnWrite()->base = make_zero(expr->dtype); + auto opt = TryFuseIters(expr); + expr.CopyOnWrite()->base = base; + if (opt) { + expr.CopyOnWrite()->args = Array({opt.value()}); + return expr; + } else { + ++unresolved_count_; + return expr; + } + } + + bool CanProveEqual(PrimExpr lhs, PrimExpr rhs) { + const auto* clhs = lhs.as(); + const auto* crhs = rhs.as(); + if (clhs && crhs) return clhs->value == crhs->value; + return analyzer_->CanProve(lhs - rhs == 0); + } + + /*! + * \brief Create a IterSumExpr from expr. + * \param expr The input expr. + * \return The transformed IterSumExpr. + */ + IterSumExpr ToIterSumExpr(PrimExpr expr) { + if (const auto* op = expr.as()) { + return GetRef(op); + } else if (const auto* op = expr.as()) { + return IterSumExpr({GetRef(op)}, make_zero(expr->dtype)); + } else { + CHECK(!expr->IsInstance()); + return IterSumExpr({}, expr); + } + } + + // Try to normalize IterSum into a fused IterMark + // return a corresponding splitexpr if needed. 
+ Optional TryFuseIters(IterSumExpr expr) { + if (!is_zero(expr->base)) return NullOpt; + if (expr->args.size() == 1) return expr->args[0]; + // select the iterators in order + std::vector visited(expr->args.size(), false); + std::vector iters; + iters.reserve(expr->args.size()); + // canonicalize the expression + // check if it can be remapped into a fused pattern. + PrimExpr expected_scale = make_const(expr->base->dtype, 1); + for (size_t i = 0; i < expr->args.size(); ++i) { + size_t j = 0; + for (; j < expr->args.size(); ++j) { + if (!visited[j] && CanProveEqual(expr->args[j]->scale, expected_scale)) break; + } + if (j == expr->args.size()) { + return NullOpt; + } + visited[j] = true; + iters.push_back(expr->args[j]); + expected_scale *= expr->args[j]->extent; + } + // update the iterator to use the canonicalized form + expr.CopyOnWrite()->args = Array(iters.rbegin(), iters.rend()); + auto it = sum_fuse_map_.find(expr); + if (it != sum_fuse_map_.end()) return it->second; + auto mark = IterMark(expr, expected_scale); + IterSplitExpr split(mark); + sum_fuse_map_[expr] = split; + return split; + } + + bool CanProveDivisible(PrimExpr lhs, PrimExpr rhs) { + const auto* clhs = lhs.as(); + const auto* crhs = rhs.as(); + if (clhs && crhs) return clhs->value % crhs->value == 0; + return analyzer_->CanProve(floormod(lhs, rhs) == 0); + } + + PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs); + PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs); + + static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) { + tir::ExprDeepEqual equal; + for (size_t i = 0; i < lhs->args.size(); ++i) { + IterSplitExpr lvalue = lhs->args[i]; + if (lvalue->source.same_as(rhs->source) && equal(lvalue->lower_factor, rhs->lower_factor) && + equal(lvalue->extent, rhs->extent)) { + if (sign > 0) { + rhs.CopyOnWrite()->scale = lvalue->scale + rhs->scale; + } else { + rhs.CopyOnWrite()->scale = lvalue->scale - rhs->scale; + } + lhs->args.Set(i, rhs); + return; + } + 
} + if (sign > 0) { + lhs->args.push_back(rhs); + } else { + rhs.CopyOnWrite()->scale = make_zero(rhs->scale.dtype()) - rhs->scale; + lhs->args.push_back(rhs); + } + } + + static void AddToLhs(IterSumExprNode* lhs, IterSumExpr rhs, int sign) { + for (size_t i = 0; i < rhs->args.size(); ++i) { + AddToLhs(lhs, rhs->args[i], sign); + } + if (sign > 0) { + lhs->base += rhs->base; + } else { + lhs->base -= rhs->base; + } + } + + static void MulToLhs(IterSumExprNode* lhs, PrimExpr rhs) { + for (size_t i = 0; i < lhs->args.size(); ++i) { + IterSplitExpr lvalue = lhs->args[i]; + lvalue.CopyOnWrite()->scale *= rhs; + lhs->args.Set(i, lvalue); + } + lhs->base *= rhs; + } +}; + +Array DetectIterMap(const Array& indices, const Map& input_iters, + arith::Analyzer* analyzer) { + // Overall detection algorithm is divided into two steps: + // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns. + // - Step1: IterIndependenceChecker checks if the iterator are independent. + IterMapRewriter rewriter(analyzer, input_iters); + Array results; + + for (PrimExpr value : indices) { + results.push_back(rewriter.Rewrite(value)); + if (rewriter.unresolved_count() != 0) return Array(); + } + if (!rewriter.CheckBijective(results)) return Array(); + + return results; +} + +TVM_REGISTER_GLOBAL("arith.DetectIterMap") + .set_body_typed([](const Array& indices, const Map& input_iters) { + arith::Analyzer ana; + return DetectIterMap(indices, input_iters, &ana); + }); + +PrimExpr IterMapRewriter::VisitExpr_(const VarNode* op) { + auto var = GetRef(op); + auto it = var_map_.find(var); + if (it != var_map_.end()) return it->second; + return std::move(var); +} + +PrimExpr IterMapRewriter::VisitExpr_(const AddNode* op) { + if (!IsIndexType(op->dtype)) { + return Parent::VisitExpr_(op); + } + + PrimExpr a = this->DirectMutate(op->a); + PrimExpr b = this->DirectMutate(op->b); + + // const folding + PrimExpr const_res = TryConstFold(a, b); + if (const_res.defined()) return 
const_res; + // does not contain iter map. + if (!a->IsInstance() && !b->IsInstance()) { + if (op->a.same_as(a) && op->b.same_as(b)) { + return GetRef(op); + } else { + return Add(a, b); + } + } + + // canonical form simplification. + IterSumExpr ret = ToIterSumExpr(std::move(a)); + + if (!b->IsInstance()) { + ret.CopyOnWrite()->base += b; + } else if (const auto* op = b.as()) { + AddToLhs(ret.CopyOnWrite(), GetRef(op), 1); + } else if (const auto* op = b.as()) { + AddToLhs(ret.CopyOnWrite(), GetRef(op), 1); + } else { + AddToLhs(ret.CopyOnWrite(), ToIterSumExpr(b), 1); + } + return std::move(ret); +} + +PrimExpr IterMapRewriter::VisitExpr_(const SubNode* op) { + if (!IsIndexType(op->dtype)) { + return Parent::VisitExpr_(op); + } + + PrimExpr a = this->DirectMutate(op->a); + PrimExpr b = this->DirectMutate(op->b); + + // const folding + PrimExpr const_res = TryConstFold(a, b); + if (const_res.defined()) return const_res; + + // does not contain iter map. + if (!a->IsInstance() && !b->IsInstance()) { + if (op->a.same_as(a) && op->b.same_as(b)) { + return GetRef(op); + } else { + return Sub(a, b); + } + } + + // canonical form simplification. + IterSumExpr ret = ToIterSumExpr(std::move(a)); + + if (!b->IsInstance()) { + ret.CopyOnWrite()->base -= b; + } else if (const auto* op = b.as()) { + AddToLhs(ret.CopyOnWrite(), GetRef(op), -1); + } else if (const auto* op = b.as()) { + AddToLhs(ret.CopyOnWrite(), GetRef(op), -1); + } else { + AddToLhs(ret.CopyOnWrite(), ToIterSumExpr(b), -1); + } + return std::move(ret); +} + +PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { + if (!IsIndexType(op->dtype)) { + return Parent::VisitExpr_(op); + } + // normalize + PrimExpr a = this->DirectMutate(op->a); + PrimExpr b = this->DirectMutate(op->b); + + // const folding + PrimExpr const_res = TryConstFold(a, b); + if (const_res.defined()) return const_res; + + // does not contain iter map. 
+ if (!a->IsInstance() && !b->IsInstance()) { + if (op->a.same_as(a) && op->b.same_as(b)) { + return GetRef(op); + } else { + return Mul(a, b); + } + } + + if (a->IsInstance() && b->IsInstance()) { + // cannot multiply two iterators, mark as unresolved. + ++unresolved_count_; + return Mul(a, b); + } + + if (!a->IsInstance()) { + std::swap(a, b); + } + + if (a->IsInstance()) { + IterSumExpr ret = Downcast(std::move(a)); + MulToLhs(ret.CopyOnWrite(), b); + return std::move(ret); + } else { + CHECK(a->IsInstance()); + IterSplitExpr ret = Downcast(std::move(a)); + ret.CopyOnWrite()->scale *= b; + return std::move(ret); + } +} + +PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs) { + if (is_one(rhs)) return std::move(lhs); + if (!is_one(lhs->scale)) { + if (CanProveDivisible(lhs->scale, rhs)) { + lhs.CopyOnWrite()->scale = floordiv(lhs->scale, rhs); + return std::move(lhs); + } else { + if (CanProveDivisible(rhs, lhs->scale)) { + rhs = floordiv(rhs, lhs->scale); + lhs.CopyOnWrite()->scale = make_const(rhs->dtype, 1); + } else { + // mark as unresolved. + ++unresolved_count_; + return floordiv(lhs, rhs); + } + } + } + + if (CanProveDivisible(lhs->extent, rhs)) { + auto* ptr_lhs = lhs.CopyOnWrite(); + ptr_lhs->lower_factor *= rhs; + ptr_lhs->extent = analyzer_->Simplify(floordiv(ptr_lhs->extent, rhs)); + return std::move(lhs); + } else { + // mark as unresolved. + ++unresolved_count_; + return floordiv(lhs, rhs); + } +} + +PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { + if (!IsIndexType(op->dtype)) { + return Parent::VisitExpr_(op); + } + + PrimExpr a = this->DirectMutate(op->a); + PrimExpr b = this->DirectMutate(op->b); + + // const folding + PrimExpr const_res = TryConstFold(a, b); + if (const_res.defined()) return const_res; + + // does not contain iter map. 
+ if (!a->IsInstance() && !b->IsInstance()) { + if (op->a.same_as(a) && op->b.same_as(b)) { + return GetRef(op); + } else { + return FloorDiv(a, b); + } + } + + if (b->IsInstance()) { + // cannot divide an iterator, mark as unresolved. + ++unresolved_count_; + return FloorDiv(a, b); + } + + if (a->IsInstance()) { + IterSumExpr ret = Downcast(std::move(a)); + if (auto opt = TryFuseIters(ret)) { + return SplitFloorDivConst(opt.value(), b); + } else { + ++unresolved_count_; + return FloorDiv(a, b); + } + } else { + CHECK(a->IsInstance()); + IterSplitExpr ret = Downcast(std::move(a)); + return SplitFloorDivConst(ret, b); + } +} + +PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs) { + if (is_one(rhs)) return make_zero(lhs->dtype); + if (!is_one(lhs->scale)) { + if (CanProveDivisible(lhs->scale, rhs)) { + return make_zero(lhs->dtype); + } else { + if (CanProveDivisible(rhs, lhs->scale)) { + rhs = floormod(rhs, lhs->scale); + } else { + // mark as unresolved. + ++unresolved_count_; + return floormod(lhs, rhs); + } + } + } + + if (CanProveDivisible(lhs->extent, rhs)) { + lhs.CopyOnWrite()->extent = rhs; + return std::move(lhs); + } else { + // mark as unresolved. + ++unresolved_count_; + return floormod(lhs, rhs); + } +} + +PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { + if (!IsIndexType(op->dtype)) { + return Parent::VisitExpr_(op); + } + + PrimExpr a = this->DirectMutate(op->a); + PrimExpr b = this->DirectMutate(op->b); + + // const folding + PrimExpr const_res = TryConstFold(a, b); + if (const_res.defined()) return const_res; + + // does not contain iter map. + if (!a->IsInstance() && !b->IsInstance()) { + if (op->a.same_as(a) && op->b.same_as(b)) { + return GetRef(op); + } else { + return FloorMod(a, b); + } + } + + if (b->IsInstance()) { + // cannot mod an iterator, mark as unresolved. 
+ ++unresolved_count_; + return FloorMod(a, b); + } + + if (a->IsInstance()) { + IterSumExpr ret = Downcast(std::move(a)); + if (auto opt = TryFuseIters(ret)) { + return SplitFloorModConst(opt.value(), b); + } else { + ++unresolved_count_; + return FloorMod(a, b); + } + } else { + CHECK(a->IsInstance()); + IterSplitExpr ret = Downcast(std::move(a)); + return SplitFloorModConst(ret, b); + } +} + +} // namespace arith +} // namespace tvm diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index c237edc493a6..cb8ef01e7369 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -882,6 +882,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE(floormod(x * y, y), ZeroWithTypeLike(x)); + TVM_TRY_REWRITE(floormod(y * x, y), ZeroWithTypeLike(y)); + // try modular analysis if (floormod(x, c1).Match(ret)) { ModularSet mod = analyzer_->modular_set(x.Eval()); diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index d21cb1f2d9b3..1122b8e1ee40 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -28,6 +28,8 @@ #include #include +#include "../support/util.h" + namespace tvm { // Define the dispatch functio here since primary user is in this file. @@ -163,7 +165,7 @@ class VarCountingSHashHandler : public SHashReducer::Handler { // combine in the reverse order of the stack. 
size_t reduced_hash = task.reduced_hash; for (size_t i = result_stack_.size(); i != stack_begin; --i) { - reduced_hash = HashCombine(reduced_hash, result_stack_[i - 1]); + reduced_hash = support::HashCombine(reduced_hash, result_stack_[i - 1]); } result_stack_.resize(stack_begin); return reduced_hash; @@ -186,8 +188,8 @@ class VarCountingSHashHandler : public SHashReducer::Handler { // Append the graph node counter to the hash // so that we can distinguish DAG from trees. if (entry.graph_node_hash) { - entry.reduced_hash = - HashCombine(entry.reduced_hash, std::hash()(graph_node_counter_++)); + entry.reduced_hash = support::HashCombine(entry.reduced_hash, + std::hash()(graph_node_counter_++)); } hash_memo_[entry.object] = entry.reduced_hash; } @@ -229,16 +231,6 @@ class VarCountingSHashHandler : public SHashReducer::Handler { vtable_->SHashReduce(object.get(), SHashReducer(this, map_free_vars)); } - /*! - * \brief Combine two hash values into a single one. - * \param key The left operand. - * \param value The right operand. - * \return the combined result. - */ - size_t HashCombine(size_t key, size_t value) { - return key ^ (value + 0x9e3779b9 + (key << 6) + (key >> 2)); - } - private: // free var counter. size_t free_var_counter_{0}; diff --git a/src/support/util.h b/src/support/util.h index 859b372bd761..5020df2e2ea7 100644 --- a/src/support/util.h +++ b/src/support/util.h @@ -152,6 +152,16 @@ inline int Execute(std::string cmd, std::string* err_msg) { return 255; } +/*! + * \brief Combine two hash values into a single one. + * \param key The left operand. + * \param value The right operand. + * \return the combined result. 
+ */ +inline size_t HashCombine(size_t key, size_t value) { + return key ^ (value + 0x9e3779b9 + (key << 6) + (key >> 2)); +} + } // namespace support } // namespace tvm #endif // TVM_SUPPORT_UTIL_H_ diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py new file mode 100644 index 000000000000..9fb098831a71 --- /dev/null +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -0,0 +1,176 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import tvm +import tvm.testing +from tvm import te + + +def ifuse(inputs): + """Fuse iterators""" + value, extent = 0, 1 + for i, ext in inputs: + value = value * ext + i + extent = extent * ext + return (value, extent) + + +def isplit(axis, factor): + """Split iterators""" + fld = tvm.tir.floordiv + flm = tvm.tir.floormod + return [ + (fld(axis[0], factor), fld(axis[1] + (factor - 1), factor)), + (flm(axis[0], factor), factor), + ] + + +def var_dom(iters): + """Get domains of iterators""" + return {var: tvm.ir.Range(0, ext) for var, ext in iters} + + +def assert_iter_sum_pattern(sum_expr, extent, base, scale=1): + """Check the sum expr have the right pattern.""" + assert isinstance(sum_expr, tvm.arith.IterSumExpr) + if extent == 1: + assert len(sum_expr.args) == 0 + else: + assert len(sum_expr.args) == 1 + tvm.testing.assert_prim_expr_equal(sum_expr.args[0].extent, extent) + tvm.testing.assert_prim_expr_equal(sum_expr.args[0].scale, scale) + tvm.testing.assert_prim_expr_equal(sum_expr.base, base) + + +def test_trivial(): + x = tvm.tir.Var("x", "int32"), 3 + y = tvm.tir.Var("y", "int32"), 4 + + res = tvm.arith.detect_iter_map([x[0], y[0], 3], var_dom([x, y])) + + assert len(res) == 3 + assert_iter_sum_pattern(res[0], 3, 0) + assert_iter_sum_pattern(res[1], 4, 0) + assert_iter_sum_pattern(res[2], 1, 3) + + res = tvm.arith.detect_iter_map([x[0], 3], var_dom([x, y])) + assert len(res) == 0 + + # not independent + res = tvm.arith.detect_iter_map([x[0], x[0], 3], var_dom([x, y])) + assert len(res) == 0 + + +def test_fuse(): + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + c = tvm.tir.SizeVar("c", "int32") + + res = tvm.arith.detect_iter_map([y * 3 + 1 + c + x], var_dom([(x, 3), (y, 4)])) + assert len(res) == 1 + assert_iter_sum_pattern(res[0], 12, 1 + c) + + res = tvm.arith.detect_iter_map([ifuse([(x, 3), (y, 4)])[0]], var_dom([(x, 3), (y, 4)])) + assert len(res) == 1 + assert_iter_sum_pattern(res[0], 12, 0) + + # fuse with symbolic factor + res = 
tvm.arith.detect_iter_map([(y + 1) * c + x], var_dom([(x, c), (y, 4)])) + assert len(res) == 1 + assert_iter_sum_pattern(res[0], 4 * c, c) + + # duplication + res = tvm.arith.detect_iter_map([y * 3 + x, y], var_dom([(x, 3), (y, 4)])) + assert len(res) == 0 + + # duplication 2 + res = tvm.arith.detect_iter_map([y, x + 1, y], var_dom([(x, 3), (y, 4)])) + assert len(res) == 0 + + # factor mismatch + res = tvm.arith.detect_iter_map([y * 4 + x], var_dom([(x, 3), (y, 4)])) + assert len(res) == 0 + + +def test_split(): + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + z = tvm.tir.Var("y", "int32") + c0 = tvm.tir.SizeVar("c0", "int32") + c1 = tvm.tir.SizeVar("c1", "int32") + c2 = tvm.tir.SizeVar("c1", "int32") + fld = tvm.tir.floordiv + flm = tvm.tir.floormod + + res = tvm.arith.detect_iter_map([fld(x, 3), flm(x, 3) * 2 + c1], var_dom([(x, 24)])) + + assert len(res) == 2 + assert_iter_sum_pattern(res[0], 8, 0) + assert_iter_sum_pattern(res[1], 3, c1, 2) + + res = tvm.arith.detect_iter_map([fld(x, 6), fld(flm(x, 6), 2), flm(x, 2)], var_dom([(x, 24)])) + + assert len(res) == 3 + assert_iter_sum_pattern(res[0], 4, 0) + assert_iter_sum_pattern(res[1], 3, 0) + assert_iter_sum_pattern(res[2], 2, 0) + + # simple symbolic bound + # TODO(tvm-team) improve symbolic divisible check to enable + # more complicated symbolic bound + res = tvm.arith.detect_iter_map([fld(x, c0), flm(x, c0)], var_dom([(x, c1 * c0)])) + + assert len(res) == 2 + assert_iter_sum_pattern(res[0], c1, 0) + assert_iter_sum_pattern(res[1], c0, 0) + + +def test_compound(): + x = tvm.tir.Var("x", "int32"), 10 + y = tvm.tir.Var("y", "int32"), 9 + + xo, xi = isplit(x, 5) + yo, yi = isplit(y, 3) + z = ifuse([yo, xo, yi]) + + res = tvm.arith.detect_iter_map([z[0], xi[0]], var_dom([x, y])) + + assert len(res) == 2 + assert_iter_sum_pattern(res[0], 18, 0) + assert_iter_sum_pattern(res[1], 5, 0) + # reconstruct the pattern manually + mx = tvm.arith.IterMark(x[0], 10) + my = tvm.arith.IterMark(y[0], 9) + + 
xoscale = 3 + xiscale = 1 + yoscale = 6 + yiscale = 1 + mxo = tvm.arith.IterSplitExpr(mx, 5, 2, xoscale) + mxi = tvm.arith.IterSplitExpr(mx, 1, 5, xiscale) + myo = tvm.arith.IterSplitExpr(my, 3, 3, yoscale) + myi = tvm.arith.IterSplitExpr(my, 1, 3, yiscale) + + mz = tvm.arith.IterMark(tvm.arith.IterSumExpr([myo, mxo, myi], 0), 18) + sz = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(mz, 1, 18, 1)], 0) + tvm.ir.assert_structural_equal(sz, res[0]) + + +if __name__ == "__main__": + test_split() + test_trivial() + test_fuse() + test_compound()