From b079ab6778a58cc2fa61b349e09dde78b4bf8532 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 2 Sep 2022 14:01:18 -0500 Subject: [PATCH 1/6] [Containers] Add Array::Map Previously, an in-place mutation could be applied to an array using `Array::MutateByApply`, but this couldn't be used for transformations that return a new array, or for transformations that return a new type. The commit adds `Array::Map`, which can map to any `ObjectRef` subclass. For mappings that return the same type, this is done by delegating to `Array::MutateByApply`, to take advantage of the same copy-on-write behavior. --- include/tvm/runtime/container/array.h | 38 +++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/include/tvm/runtime/container/array.h b/include/tvm/runtime/container/array.h index 26f4e545deb7..0589dfae5f83 100644 --- a/include/tvm/runtime/container/array.h +++ b/include/tvm/runtime/container/array.h @@ -574,6 +574,44 @@ class Array : public ObjectRef { /*! \return The underlying ArrayNode */ ArrayNode* GetArrayNode() const { return static_cast(data_.get()); } + /*! + * \brief Helper function to apply a map function onto the array. + * + * \param fmap The transformation function T -> U. + * + * \tparam F The type of the mutation function. + * + * \tparam U The type of the returned array, inferred from the + * return type of F. If overridden by the user, must be something + * that is convertible from the return type of F. + * + * \note This function performs copy on write optimization. If + * `fmap` returns an object of type `T`, and all elements of the + * array are mapped to themselves, then the returned array will be + * the same as the original, and reference counts of the elements in + * the array will not be incremented. + * + * \return The transformed array. 
+ */ + template > + Array Map(F fmap) const { + if constexpr (std::is_same_v) { + // Special case for outputs of the same type, may be able to use + // MutateByApply's in-place handling to avoid copying data, if + // the mapping function is the identity for all elements. + Array output = *this; + output.MutateByApply(fmap); + return output; + } else { + Array output; + output.reserve(size()); + for (T item : *this) { + output.push_back(fmap(std::move(item))); + } + return output; + } + } + /*! * \brief Helper function to apply fmutate to mutate an array. * \param fmutate The transformation function T -> T. From 655923196e8959ab1e525ab4c1b63e1dc0b8bcbd Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 2 Sep 2022 14:06:53 -0500 Subject: [PATCH 2/6] [Refactor] Use Array::Map where possible With the new `Array::Map` functionality, many places that previously used explicit loops or `tvm::tir::MutateArray` can be cleaned. --- src/ir/type_functor.cc | 9 +---- src/te/operation/create_primfunc.cc | 2 +- src/tir/analysis/device_constraint_utils.cc | 5 +-- src/tir/ir/buffer.cc | 4 +- src/tir/ir/expr.cc | 3 +- src/tir/ir/expr_functor.cc | 14 +++---- src/tir/ir/functor_common.h | 3 +- src/tir/ir/index_map.cc | 5 +-- src/tir/ir/specialize.cc | 19 ++++------ src/tir/ir/stmt_functor.cc | 3 +- .../schedule/primitive/decompose_padding.cc | 15 ++++---- src/tir/schedule/transform.cc | 8 ++-- src/tir/transforms/inject_virtual_thread.cc | 4 +- src/tir/transforms/lower_match_buffer.cc | 8 ++-- src/tir/transforms/renew_defs.cc | 37 +++++++++---------- src/tir/transforms/vectorize_loop.cc | 6 +-- 16 files changed, 61 insertions(+), 84 deletions(-) diff --git a/src/ir/type_functor.cc b/src/ir/type_functor.cc index 51d5d3778c10..36838b62aabc 100644 --- a/src/ir/type_functor.cc +++ b/src/ir/type_functor.cc @@ -97,14 +97,7 @@ Type TypeMutator::VisitType(const Type& t) { Array TypeMutator::MutateArray(Array arr) { // The array will do copy on write // If no changes are made, the 
original array will be returned. - for (size_t i = 0; i < arr.size(); ++i) { - Type ty = arr[i]; - Type new_ty = VisitType(ty); - if (!ty.same_as(new_ty)) { - arr.Set(i, new_ty); - } - } - return arr; + return arr.Map([this](const Type& ty) { return VisitType(ty); }); } Type TypeMutator::VisitType_(const TypeVarNode* op) { return GetRef(op); } diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 4c1358f42519..fb325684e65b 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -256,7 +256,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op, // TensorIR will not allow Tensor data structure if (value->IsInstance()) { const auto array_value = Downcast>(value); - annotations.Set(key, MutateArray(array_value, mutate_attr)); + annotations.Set(key, array_value.Map(mutate_attr)); } else { annotations.Set(key, mutate_attr(value)); } diff --git a/src/tir/analysis/device_constraint_utils.cc b/src/tir/analysis/device_constraint_utils.cc index 1309681513a9..32b59ce54b69 100644 --- a/src/tir/analysis/device_constraint_utils.cc +++ b/src/tir/analysis/device_constraint_utils.cc @@ -393,9 +393,8 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { } template - Array VisitItems(Array items) { - items.MutateByApply([this](const T& item) { return VisitItem(item.get()); }); // copy-on-write - return items; + Array VisitItems(const Array& items) { + return items.Map([this](T item) -> T { return VisitItem(item.get()); }); } Stmt VisitStmt_(const BlockNode* block_node) final { diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index cae4109a6026..0dfda954b818 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -461,8 +461,8 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const ICHECK(n != nullptr); arith::Analyzer ana; begins = SimplifyArray(&ana, begins); - Array elem_offset = n->ElemOffset(begins); - elem_offset.MutateByApply([&](const PrimExpr& 
expr) { return ana.Simplify(expr); }); + Array elem_offset = + n->ElemOffset(begins).Map([&](const PrimExpr& expr) { return ana.Simplify(expr); }); Array strides = n->strides; if (strides.size() == 0) { diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 59db4ea410fd..daae7eaf68f5 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -994,8 +994,7 @@ Array CommReducerNode::operator()(Array a, Array b value_map.Set(lhs[i], a[i]); value_map.Set(rhs[i], b[i]); } - auto ret = this->result; - ret.MutateByApply([&value_map](const PrimExpr& e) { return Substitute(e, value_map); }); + auto ret = this->result.Map([&value_map](const PrimExpr& e) { return Substitute(e, value_map); }); return ret; } diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index c8dc84695b4f..da02e0316f48 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tir/ir/expr_functor.cc @@ -132,7 +132,7 @@ PrimExpr ExprMutator::VisitExpr_(const LoadNode* op) { PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) { auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; - Array indices = MutateArray(op->indices, fmutate); + Array indices = op->indices.Map(fmutate); if (indices.same_as(op->indices)) { return GetRef(op); } else { @@ -142,7 +142,7 @@ PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) { PrimExpr ExprMutator::VisitExpr_(const ProducerLoadNode* op) { auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; - Array indices = MutateArray(op->indices, fmutate); + Array indices = op->indices.Map(fmutate); if (indices.same_as(op->indices)) { return GetRef(op); } else { @@ -162,7 +162,7 @@ PrimExpr ExprMutator::VisitExpr_(const LetNode* op) { PrimExpr ExprMutator::VisitExpr_(const CallNode* op) { auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; - Array args = MutateArray(op->args, fmutate); + Array args = op->args.Map(fmutate); if (args.same_as(op->args)) { return GetRef(op); @@ -218,11 +218,11 @@ PrimExpr 
ExprMutator::VisitExpr_(const ReduceNode* op) { return IterVar(Range::FromMinExtent(min, extent), v->var, v->iter_type, v->thread_tag); } }; - Array axis = MutateArray(op->axis, fitervar); + Array axis = op->axis.Map(fitervar); auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); }; - Array source = MutateArray(op->source, fexpr); - Array init = MutateArray(op->init, fexpr); + Array source = op->source.Map(fexpr); + Array init = op->init.Map(fexpr); PrimExpr condition = this->VisitExpr(op->condition); @@ -285,7 +285,7 @@ PrimExpr ExprMutator::VisitExpr_(const BroadcastNode* op) { PrimExpr ExprMutator::VisitExpr_(const ShuffleNode* op) { auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); }; - auto vectors = MutateArray(op->vectors, fexpr); + auto vectors = op->vectors.Map(fexpr); if (vectors.same_as(op->vectors)) { return GetRef(op); } else { diff --git a/src/tir/ir/functor_common.h b/src/tir/ir/functor_common.h index 8b5a361a37c6..b9bb43ca6ba6 100644 --- a/src/tir/ir/functor_common.h +++ b/src/tir/ir/functor_common.h @@ -38,8 +38,7 @@ inline void VisitArray(const Array& arr, F fvisit) { template inline Array MutateArray(Array arr, F fmutate) { - arr.MutateByApply(fmutate); - return arr; + return arr.Map(fmutate); } } // namespace tir diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index 0e3c3b2774c8..562c31d8cb20 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -178,9 +178,8 @@ Array IndexMapNode::MapIndices(const Array& indices, analyzer = &local_analyzer; } - Array output = final_indices; - output.MutateByApply( - [&](const PrimExpr& index) { return analyzer->Simplify(Substitute(index, vmap)); }); + Array output = final_indices.Map( + [&](PrimExpr index) { return analyzer->Simplify(Substitute(std::move(index), vmap)); }); return output; } diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index 520e3ee03c92..ea68015bc73b 100644 --- a/src/tir/ir/specialize.cc +++ 
b/src/tir/ir/specialize.cc @@ -115,8 +115,7 @@ class PrimFuncSpecializer : public StmtExprMutator { private: Stmt VisitStmt_(const BlockNode* op) final { // Step.0. Define buffer mappings which is allocated inside the block - Array alloc_buffers = MutateArray( - op->alloc_buffers, + Array alloc_buffers = op->alloc_buffers.Map( std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1)); // Step.1. Recursively visit block body @@ -124,11 +123,9 @@ class PrimFuncSpecializer : public StmtExprMutator { op = stmt.as(); ICHECK(op != nullptr); - Array reads = MutateArray( - op->reads, + Array reads = op->reads.Map( std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1)); - Array writes = MutateArray( - op->writes, + Array writes = op->writes.Map( std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1)); if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads)) { @@ -200,10 +197,9 @@ class PrimFuncSpecializer : public StmtExprMutator { private: Buffer MutateBuffer(const Buffer& buffer) { - Array shape = - MutateArray(buffer->shape, [this](const PrimExpr& e) { return VisitExpr(e); }); + Array shape = buffer->shape.Map([this](const PrimExpr& e) { return VisitExpr(e); }); Array strides = - MutateArray(buffer->strides, [this](const PrimExpr& e) { return VisitExpr(e); }); + buffer->strides.Map([this](const PrimExpr& e) { return VisitExpr(e); }); PrimExpr elem_offset = VisitExpr(buffer->elem_offset); @@ -242,9 +238,8 @@ class PrimFuncSpecializer : public StmtExprMutator { BufferRegion MutateBufferRegion(const BufferRegion& buffer_region) { auto it = buffer_map_.find(buffer_region->buffer); - Array region = - MutateArray(buffer_region->region, - std::bind(&PrimFuncSpecializer::MutateRange, this, std::placeholders::_1)); + Array region = buffer_region->region.Map( + std::bind(&PrimFuncSpecializer::MutateRange, this, std::placeholders::_1)); if (it == buffer_map_.end() && 
region.same_as(buffer_region->region)) { return buffer_region; } else { diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index c75eb52f9296..c2e2489cba92 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -183,9 +183,8 @@ class StmtMutator::Internal { return arr; } else { bool allow_cow = false; - Array copy = arr; std::swap(allow_cow, self->allow_copy_on_write_); - copy.MutateByApply(fmutate); + Array copy = arr.Map(fmutate); std::swap(allow_cow, self->allow_copy_on_write_); return copy; } diff --git a/src/tir/schedule/primitive/decompose_padding.cc b/src/tir/schedule/primitive/decompose_padding.cc index 93fb88e66619..c41760876722 100644 --- a/src/tir/schedule/primitive/decompose_padding.cc +++ b/src/tir/schedule/primitive/decompose_padding.cc @@ -212,16 +212,15 @@ static std::pair CreateConstBlock(const BlockRealizeNode* re // create new write region ICHECK_EQ(block->writes.size(), 1U); - BufferRegion write_region = - BufferRegion(block->writes[0]->buffer, - MutateArray(block->writes[0]->region, [rewrite_expr](const Range& r) { - return Range::FromMinExtent(rewrite_expr(r->min), rewrite_expr(r->extent)); - })); + BufferRegion write_region = BufferRegion( + block->writes[0]->buffer, block->writes[0]->region.Map([rewrite_expr](const Range& r) { + return Range::FromMinExtent(rewrite_expr(r->min), rewrite_expr(r->extent)); + })); // create block to fill const pad values BufferStore store = Downcast(block->body); store.CopyOnWrite()->value = info.pad_value; - store.CopyOnWrite()->indices = MutateArray(store->indices, rewrite_expr); + store.CopyOnWrite()->indices = store->indices.Map(rewrite_expr); Block new_block(/*iter_vars=*/new_iter_vars, /*reads=*/{}, /*writes=*/{write_region}, /*name_hint=*/block->name_hint + "_pad_const", /*body=*/std::move(store)); @@ -307,7 +306,7 @@ static std::pair CreateInBoundBlock(const BlockRealizeNode* return analyzer->Simplify(Substitute(e, repl_dict)); }; auto rewrite_region = 
[rewrite_expr](const Region& region) { - return MutateArray(region, [rewrite_expr](const Range& r) { + return region.Map([rewrite_expr](const Range& r) { return Range::FromMinExtent(rewrite_expr(r->min), rewrite_expr(r->extent)); }); }; @@ -324,7 +323,7 @@ static std::pair CreateInBoundBlock(const BlockRealizeNode* // create new block realize node BufferStore store = Downcast(block->body); store.CopyOnWrite()->value = rewrite_expr(info.in_bound_value); - store.CopyOnWrite()->indices = MutateArray(store->indices, rewrite_expr); + store.CopyOnWrite()->indices = store->indices.Map(rewrite_expr); Block new_block(/*iter_vars=*/new_iter_vars, /*reads=*/reads, /*writes=*/writes, /*name_hint=*/block->name_hint, /*body=*/std::move(store)); PrimExpr new_predicate = rewrite_expr(info.in_bound_predicate); diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index 1ebaf202d487..6832cdd1123c 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -148,12 +148,12 @@ Stmt ReplaceBufferMutator::VisitStmt_(const BlockNode* block) { }; // Step 1. Mutate `match_buffers`. If an old buffer appears as a source of MatchBufferRegion, - Array match_buffers = MutateArray(block->match_buffers, f_mutate_match_buffer); + Array match_buffers = block->match_buffers.Map(f_mutate_match_buffer); // Step 2. Mutate the read/write region. - Array reads = MutateArray(block->reads, f_mutate_read_write_region); - Array writes = MutateArray(block->writes, f_mutate_read_write_region); + Array reads = block->reads.Map(f_mutate_read_write_region); + Array writes = block->writes.Map(f_mutate_read_write_region); // Step 3. Mutate `alloc_buffers` for the old buffer allocated in this block. - Array alloc_buffers = MutateArray(block->alloc_buffers, f_mutate_alloc_buffers); + Array alloc_buffers = block->alloc_buffers.Map(f_mutate_alloc_buffers); // Step 4. Recursively mutate the block. 
Block mutated_block = Downcast(StmtMutator::VisitStmt_(block)); diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index 455140c75c13..f49b6b2ace8e 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -400,8 +400,8 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { PrimExpr condition = this->VisitExpr(op->condition); - Array extents = op->extents; - extents.MutateByApply([this](const PrimExpr& extent) { return this->VisitExpr(extent); }); + Array extents = + op->extents.Map([this](const PrimExpr& extent) { return this->VisitExpr(extent); }); if (visit_touched_var_ && !vt_loop_injected_) { return InjectVTLoop(GetRef(op), true); diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc index 5bde5cb90e2b..9b915da6290b 100644 --- a/src/tir/transforms/lower_match_buffer.cc +++ b/src/tir/transforms/lower_match_buffer.cc @@ -51,10 +51,10 @@ class MatchBufferLower : public StmtExprMutator { Stmt stmt = StmtExprMutator ::VisitStmt_(op); op = stmt.as(); ICHECK(op != nullptr); - Array reads = MutateArray( - op->reads, std::bind(&MatchBufferLower::VisitBufferRegion, this, std::placeholders::_1)); - Array writes = MutateArray( - op->writes, std::bind(&MatchBufferLower::VisitBufferRegion, this, std::placeholders::_1)); + Array reads = + op->reads.Map(std::bind(&MatchBufferLower::VisitBufferRegion, this, std::placeholders::_1)); + Array writes = op->writes.Map( + std::bind(&MatchBufferLower::VisitBufferRegion, this, std::placeholders::_1)); if (reads.same_as(op->reads) && writes.same_as(op->writes) && op->match_buffers.empty()) { return stmt; diff --git a/src/tir/transforms/renew_defs.cc b/src/tir/transforms/renew_defs.cc index c717dc9b98f2..a185916a9a4c 100644 --- a/src/tir/transforms/renew_defs.cc +++ b/src/tir/transforms/renew_defs.cc @@ -96,18 +96,16 @@ class RenewDefMutator : public StmtExprMutator { Stmt 
VisitStmt_(const BlockNode* op) final { // Step 0. Re-define Itervars - Array iter_vars = MutateArray( - op->iter_vars, std::bind(&RenewDefMutator::VisitIterVar, this, std::placeholders::_1)); + Array iter_vars = + op->iter_vars.Map(std::bind(&RenewDefMutator::VisitIterVar, this, std::placeholders::_1)); // Step 1. Re-define buffers allocate under the block - Array alloc_buffers = MutateArray( - op->alloc_buffers, + Array alloc_buffers = op->alloc_buffers.Map( std::bind(&RenewDefMutator::VisitBuffer, this, std::placeholders::_1, /*define=*/true)); // Step 2. Re-define match_buffers - Array match_buffers = - MutateArray(op->match_buffers, - std::bind(&RenewDefMutator::VisitMatchBuffer, this, std::placeholders::_1)); + Array match_buffers = op->match_buffers.Map( + std::bind(&RenewDefMutator::VisitMatchBuffer, this, std::placeholders::_1)); // Step 3. Visit body Stmt stmt = StmtExprMutator::VisitStmt_(op); @@ -115,10 +113,10 @@ class RenewDefMutator : public StmtExprMutator { ICHECK(op); // Step 4. Revisit access region - Array reads = MutateArray( - op->reads, std::bind(&RenewDefMutator::VisitBufferRegion, this, std::placeholders::_1)); - Array writes = MutateArray( - op->writes, std::bind(&RenewDefMutator::VisitBufferRegion, this, std::placeholders::_1)); + Array reads = + op->reads.Map(std::bind(&RenewDefMutator::VisitBufferRegion, this, std::placeholders::_1)); + Array writes = + op->writes.Map(std::bind(&RenewDefMutator::VisitBufferRegion, this, std::placeholders::_1)); // Step 5. Regenerate block. 
Since the defs are changed, we need to create a new block auto n = make_object(*op); @@ -203,9 +201,9 @@ class RenewDefMutator : public StmtExprMutator { // update data Var data = Downcast(redefine_if_is_var(buffer->data)); // update shape - Array shape = MutateArray(buffer->shape, redefine_if_is_var); + Array shape = buffer->shape.Map(redefine_if_is_var); // update strides - Array strides = MutateArray(buffer->strides, redefine_if_is_var); + Array strides = buffer->strides.Map(redefine_if_is_var); // update elem_offset PrimExpr elem_offset = redefine_if_is_var(buffer->elem_offset); @@ -242,10 +240,10 @@ class RenewDefMutator : public StmtExprMutator { return Downcast((*it).second); } Var data = Downcast(VisitExpr(buffer->data)); - Array shape = MutateArray( - buffer->shape, std::bind(&RenewDefMutator::VisitExpr, this, std::placeholders::_1)); - Array strides = MutateArray( - buffer->strides, std::bind(&RenewDefMutator::VisitExpr, this, std::placeholders::_1)); + Array shape = + buffer->shape.Map(std::bind(&RenewDefMutator::VisitExpr, this, std::placeholders::_1)); + Array strides = + buffer->strides.Map(std::bind(&RenewDefMutator::VisitExpr, this, std::placeholders::_1)); PrimExpr elem_offset = VisitExpr(buffer->elem_offset); auto n = make_object(*buffer.get()); @@ -276,9 +274,8 @@ class RenewDefMutator : public StmtExprMutator { BufferRegion VisitBufferRegion(const BufferRegion& buffer_region) { Buffer buffer = VisitBuffer(buffer_region->buffer); - Array region = - MutateArray(buffer_region->region, - std::bind(&RenewDefMutator::VisitRange, this, std::placeholders::_1)); + Array region = buffer_region->region.Map( + std::bind(&RenewDefMutator::VisitRange, this, std::placeholders::_1)); if (buffer.same_as(buffer_region->buffer) && region.same_as(buffer_region->region)) { return buffer_region; } else { diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 5c5a47e86a9a..3cc17847e69b 100644 --- 
a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -379,8 +379,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor(op); auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); }; - Array indices = op->indices; - indices.MutateByApply(fmutate); + Array indices = op->indices.Map(fmutate); if (!indices.same_as(op->indices)) { auto writer = load.CopyOnWrite(); @@ -428,8 +427,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor(op); auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); }; - Array indices = op->indices; - indices.MutateByApply(fmutate); + Array indices = op->indices.Map(fmutate); PrimExpr value = this->VisitExpr(op->value); From 95ad71167367bbf0180418c98323a347c0fbcac1 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 6 Sep 2022 10:02:49 -0500 Subject: [PATCH 3/6] Merge the Map and MutateInPlace implementations --- include/tvm/runtime/container/array.h | 169 +++++++++++++++++--------- 1 file changed, 113 insertions(+), 56 deletions(-) diff --git a/include/tvm/runtime/container/array.h b/include/tvm/runtime/container/array.h index 0589dfae5f83..934afb77e512 100644 --- a/include/tvm/runtime/container/array.h +++ b/include/tvm/runtime/container/array.h @@ -26,10 +26,12 @@ #include #include +#include #include #include #include "./base.h" +#include "./optional.h" namespace tvm { namespace runtime { @@ -248,6 +250,23 @@ class ArrayNode : public Object, public InplaceArrayBase { friend ObjectPtr make_object<>(); }; +/*! \brief Helper struct for type-checking + * + * is_valid_iterator::value will be true if IterType can + * be dereferenced into a type that can be stored in an Array, and + * false otherwise. + */ +template +struct is_valid_iterator + : std::bool_constant())>>>> {}; + +template +struct is_valid_iterator, IterType> : is_valid_iterator {}; + +template +inline constexpr bool is_valid_iterator_v = is_valid_iterator::value; + /*! 
* \brief Array, container representing a contiguous sequence of ObjectRefs. * @@ -595,21 +614,7 @@ class Array : public ObjectRef { */ template > Array Map(F fmap) const { - if constexpr (std::is_same_v) { - // Special case for outputs of the same type, may be able to use - // MutateByApply's in-place handling to avoid copying data, if - // the mapping function is the identity for all elements. - Array output = *this; - output.MutateByApply(fmap); - return output; - } else { - Array output; - output.reserve(size()); - for (T item : *this) { - output.push_back(fmap(std::move(item))); - } - return output; - } + return Array(MapHelper(data_, fmap)); } /*! @@ -618,48 +623,9 @@ class Array : public ObjectRef { * \tparam F the type of the mutation function. * \note This function performs copy on write optimization. */ - template + template >>> void MutateByApply(F fmutate) { - if (data_ == nullptr) { - return; - } - struct StackFrame { - ArrayNode* p; - ObjectRef* itr; - int64_t i; - int64_t size; - }; - std::unique_ptr s = std::make_unique(); - s->p = GetArrayNode(); - s->itr = s->p->MutableBegin(); - s->i = 0; - s->size = s->p->size_; - if (!data_.unique()) { - // Loop invariant: keeps iterating when - // 1) data is not unique - // 2) no elements are actually mutated yet - for (; s->i < s->size; ++s->i, ++s->itr) { - T new_elem = fmutate(DowncastNoCheck(*s->itr)); - // do nothing when there is no mutation - if (new_elem.same_as(*s->itr)) { - continue; - } - // loop invariant breaks when the first real mutation happens - // we copy the elements into a new unique array - ObjectPtr copy = ArrayNode::CopyFrom(s->p->capacity_, s->p); - s->itr = copy->MutableBegin() + (s->i++); - *s->itr++ = std::move(new_elem); - data_ = std::move(copy); - // make sure `data_` is unique and break - break; - } - } - // when execution comes to this line, it is guaranteed that either - // 1) i == size - // or 2) data_.unique() is true - for (; s->i < s->size; ++s->i, ++s->itr) { - *s->itr = 
std::move(fmutate(std::move(DowncastNoCheck(std::move(*s->itr))))); - } + data_ = MapHelper(std::move(data_), fmutate); } /*! @@ -744,6 +710,97 @@ class Array : public ObjectRef { } return static_cast(data_.get()); } + + /*! \brief Helper method for mutate/map + * + * A helper function used internally by both `Array::Map` and + * `Array::MutateByApply`. Given an array of data, apply the + * mapping function to each element, returning the collected array. + * Applies both mutate-in-place and copy-on-write optimizations, if + * possible. + * + * \param data A pointer to the ArrayNode containing input data. + * Passed by value to allow for mutate-in-place optimizations. + * + * \param fmap The mapping function + * + * \tparam F The type of the mutation function. + * + * \tparam U The output type of the mutation function. Inferred + * from the callable type given. Must inherit from ObjectRef. + * + * \return The mapped array. Depending on whether mutate-in-place + * or copy-on-write optimizations were applicable, may be the same + * underlying array as the `data` parameter. + */ + template > + static ObjectPtr MapHelper(ObjectPtr data, F fmap) { + if (data == nullptr) { + return nullptr; + } + + ICHECK(data->IsInstance()); + + constexpr bool is_same_output_type = std::is_same_v; + + if constexpr (is_same_output_type) { + if (data.unique()) { + // Mutate-in-place path. Only allowed if the output type U is + // the same as type T, we have a mutable this*, and there are + // no other shared copies of the array. 
+ auto arr = static_cast(data.get()); + for (auto it = arr->MutableBegin(); it != arr->MutableEnd(); it++) { + T mapped = fmap(DowncastNoCheck(std::move(*it))); + *it = std::move(mapped); + } + return data; + } + } + + constexpr bool compatible_types = is_valid_iterator_v || is_valid_iterator_v; + + ObjectPtr output = nullptr; + auto arr = static_cast(data.get()); + + auto it = arr->begin(); + if constexpr (compatible_types) { + // Copy-on-write path, if the output Array might be + // represented by the same underlying array as the existing + // Array. Typically, this is for functions that map `T` to + // `T`, but can also apply to functions that map `T` to + // `Optional`, or that map `T` to a subclass or superclass of + // `T`. + bool all_identical = true; + for (; it != arr->end(); it++) { + U mapped = fmap(DowncastNoCheck(*it)); + if (!mapped.same_as(*it)) { + all_identical = false; + output = ArrayNode::CreateRepeated(arr->size(), U()); + output->InitRange(0, arr->begin(), it); + output->SetItem(it - arr->begin(), std::move(mapped)); + break; + } + } + if (all_identical) { + return data; + } + } else { + // Path for incompatible types. The constexpr check for + // compatible types isn't strictly necessary, as the first + // mapped.same_as(*it) would return false, but we might as well + // avoid it altogether. + output = ArrayNode::CreateRepeated(arr->size(), U()); + } + + // Normal path for incompatible types, or post-copy path for + // copy-on-write instances. + for (; it != arr->end(); it++) { + U mapped = fmap(DowncastNoCheck(*it)); + output->SetItem(it - arr->begin(), std::move(mapped)); + } + + return output; + } }; /*! 
From b38e5b574275bda79396bcb1a7cbbfa2e512ca9b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 7 Sep 2022 11:37:10 -0500 Subject: [PATCH 4/6] Fix off-by-one error in MapHelper --- include/tvm/runtime/container/array.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/tvm/runtime/container/array.h b/include/tvm/runtime/container/array.h index 934afb77e512..796595486cde 100644 --- a/include/tvm/runtime/container/array.h +++ b/include/tvm/runtime/container/array.h @@ -778,6 +778,7 @@ class Array : public ObjectRef { output = ArrayNode::CreateRepeated(arr->size(), U()); output->InitRange(0, arr->begin(), it); output->SetItem(it - arr->begin(), std::move(mapped)); + it++; break; } } From 0917cc1a10983b5540d79fed747e74a713024b30 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 19 Sep 2022 16:13:23 -0500 Subject: [PATCH 5/6] Updated with unit tests for Array::Map conversions --- include/tvm/runtime/container/array.h | 2 +- tests/cpp/container_test.cc | 135 ++++++++++++++++++++++++++ 2 files changed, 136 insertions(+), 1 deletion(-) diff --git a/include/tvm/runtime/container/array.h b/include/tvm/runtime/container/array.h index 796595486cde..ba72b27e1364 100644 --- a/include/tvm/runtime/container/array.h +++ b/include/tvm/runtime/container/array.h @@ -614,7 +614,7 @@ class Array : public ObjectRef { */ template > Array Map(F fmap) const { - return Array(MapHelper(data_, fmap)); + return Array(MapHelper(data_, fmap)); } /*! 
diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index f6c4fb4b67d6..d75a510d0c95 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -173,6 +173,141 @@ TEST(Array, Mutate) { ICHECK(list2[1].same_as(z)); } +TEST(Array, MutateInPlaceForUniqueReference) { + using namespace tvm; + Var x("x"); + Array arr{x, x}; + ICHECK(arr.unique()); + auto* before = arr.get(); + + arr.MutateByApply([](Var) { return Var("y"); }); + auto* after = arr.get(); + ICHECK_EQ(before, after); +} + +TEST(Array, CopyWhenMutatingNonUniqueReference) { + using namespace tvm; + Var x("x"); + Array arr{x, x}; + Array arr2 = arr; + + ICHECK(!arr.unique()); + auto* before = arr.get(); + + arr.MutateByApply([](Var) { return Var("y"); }); + auto* after = arr.get(); + ICHECK_NE(before, after); +} + +TEST(Array, Map) { + // Basic functionality + using namespace tvm; + Var x("x"); + Var y("y"); + Array var_arr{x, y}; + Array expr_arr = var_arr.Map([](Var var) -> PrimExpr { return var + 1; }); + + ICHECK_NE(var_arr.get(), expr_arr.get()); + ICHECK(expr_arr[0]->IsInstance()); + ICHECK(expr_arr[1]->IsInstance()); + ICHECK(expr_arr[0].as()->a.same_as(x)); + ICHECK(expr_arr[1].as()->a.same_as(y)); +} + +TEST(Array, MapToSameTypeWithoutCopy) { + // If the applied map doesn't alter the contents, we can avoid a + // copy. + using namespace tvm; + Var x("x"); + Var y("y"); + Array var_arr{x, y}; + Array var_arr2 = var_arr.Map([](Var var) { return var; }); + + ICHECK_EQ(var_arr.get(), var_arr2.get()); +} + +TEST(Array, MapToSameTypeWithCopy) { + // If the applied map does alter the contents, we need to make a + // copy. The loop in this test is to validate correct behavior + // regardless of where the first discrepancy occurs. 
+ using namespace tvm; + Var x("x"); + Var y("y"); + Var z("z"); + Var replacement("replacement"); + for (size_t i = 0; i < 2; i++) { + Array var_arr{x, y, z}; + Var to_replace = var_arr[i]; + Array var_arr2 = + var_arr.Map([&](Var var) { return var.same_as(to_replace) ? replacement : var; }); + + ICHECK_NE(var_arr.get(), var_arr2.get()); + + // The original array is unchanged + ICHECK_EQ(var_arr.size(), 3); + ICHECK(var_arr[0].same_as(x)); + ICHECK(var_arr[1].same_as(y)); + + // The returned array has one of the elements replaced. + ICHECK_EQ(var_arr2.size(), 3); + ICHECK(var_arr2[i].same_as(replacement)); + ICHECK(i == 0 || var_arr2[0].same_as(x)); + ICHECK(i == 1 || var_arr2[1].same_as(y)); + ICHECK(i == 2 || var_arr2[2].same_as(z)); + } +} + +TEST(Array, MapToSuperclassWithoutCopy) { + // If a map is converting to a superclass, and the mapping function + // doesn't change the value other than a cast, we can avoid a + // copy. + using namespace tvm; + Var x("x"); + Var y("y"); + Array var_arr{x, y}; + Array expr_arr = var_arr.Map([](Var var) { return PrimExpr(var); }); + + ICHECK_EQ(var_arr.get(), expr_arr.get()); +} + +TEST(Array, MapToSubclassWithoutCopy) { + // If a map is converting to a subclass, and the mapped array + // happens to only contain instances of that subclass, we can + // avoid a copy. + using namespace tvm; + Var x("x"); + Var y("y"); + Array expr_arr{x, y}; + Array var_arr = expr_arr.Map([](PrimExpr expr) -> Var { return Downcast(expr); }); + + ICHECK_EQ(var_arr.get(), expr_arr.get()); +} + +TEST(Array, MapToOptionalWithoutCopy) { + // Optional and T both have the same T::ContainerType, just with + // different interfaces for handling `T::data_ == nullptr`. 
+ using namespace tvm; + Var x("x"); + Var y("y"); + Array var_arr{x, y}; + Array> opt_arr = var_arr.Map([](Var var) { return Optional(var); }); + + ICHECK_EQ(var_arr.get(), opt_arr.get()); +} + +TEST(Array, MapFromOptionalWithoutCopy) { + // Optional and T both have the same T::ContainerType, just with + // different interfaces for handling `T::data_ == nullptr`. + using namespace tvm; + Var x("x"); + Var y("y"); + Array> opt_arr{x, y}; + Array var_arr = + opt_arr.Map([](Optional var) { return var.value_or(Var("undefined")); }); + + ICHECK_EQ(var_arr.get(), opt_arr.get()); +} + TEST(Array, Iterator) { using namespace tvm; Array array{1, 2, 3}; From 2285ea1d9690f04831b3bd5198a851905471c470 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 20 Sep 2022 08:50:57 -0500 Subject: [PATCH 6/6] Improved comments explaining the copy-on-write in MapHelper --- include/tvm/runtime/container/array.h | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/include/tvm/runtime/container/array.h b/include/tvm/runtime/container/array.h index ba72b27e1364..11bacb18e92c 100644 --- a/include/tvm/runtime/container/array.h +++ b/include/tvm/runtime/container/array.h @@ -774,6 +774,11 @@ class Array : public ObjectRef { for (; it != arr->end(); it++) { U mapped = fmap(DowncastNoCheck(*it)); if (!mapped.same_as(*it)) { + // At least one mapped element is different than the + // original. Therefore, prepare the output array, + // consisting of any previous elements that had mapped to + // themselves (if any), and the element that didn't map to + // itself. all_identical = false; output = ArrayNode::CreateRepeated(arr->size(), U()); output->InitRange(0, arr->begin(), it); @@ -795,6 +800,21 @@ class Array : public ObjectRef { // Normal path for incompatible types, or post-copy path for // copy-on-write instances. + // + // If the types are incompatible, then at this point `output` is + // empty, and `it` points to the first element of the input. 
+ // + // If the types were compatible, then at this point `output` + // contains zero or more elements that mapped to themselves + // followed by the first element that does not map to itself, and + // `it` points to the element just after the first element that + // does not map to itself. Because at least one element has been + // changed, we no longer have the opportunity to avoid a copy, so + // we don't need to check the result. + // + // In both cases, `it` points to the next element to be processed, + // so we can either start or resume the iteration from that point, + // with no further checks on the result. for (; it != arr->end(); it++) { U mapped = fmap(DowncastNoCheck(*it)); output->SetItem(it - arr->begin(), std::move(mapped));