Skip to content

Commit

Permalink
[Containers] Add Array::Map (#12692)
Browse files Browse the repository at this point in the history
* [Containers] Add Array::Map

Previously, an in-place mutation could be applied to an array using
`Array::MutateByApply`, but this couldn't be used for transformations
that return a new array, or for transformations that return a new
type.

The commit adds `Array::Map`, which can map to any `ObjectRef`
subclass.  For mappings that return the same type, this is done by
delegating to `Array::MutateByApply`, to take advantage of the same
copy-on-write behavior.

* [Refactor] Use Array::Map where possible

With the new `Array::Map` functionality, many places that previously
used explicit loops or `tvm::tir::MutateArray` can be cleaned.

* Merge the Map and MutateInPlace implementations

* Fix off-by-one error in MapHelper

* Updated with unit tests for Array::Map conversions

* Improved comments explaining the copy-on-write in MapHelper
  • Loading branch information
Lunderberg authored Sep 20, 2022
1 parent 5dfa8da commit 534378b
Show file tree
Hide file tree
Showing 18 changed files with 353 additions and 125 deletions.
198 changes: 157 additions & 41 deletions include/tvm/runtime/container/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@

#include <algorithm>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>

#include "./base.h"
#include "./optional.h"

namespace tvm {
namespace runtime {
Expand Down Expand Up @@ -248,6 +250,23 @@ class ArrayNode : public Object, public InplaceArrayBase<ArrayNode, ObjectRef> {
friend ObjectPtr<ArrayNode> make_object<>();
};

/*! \brief Helper struct for type-checking
*
* is_valid_iterator<T,IterType>::value will be true if IterType can
* be dereferenced into a type that can be stored in an Array<T>, and
* false otherwise.
*/
template <typename T, typename IterType>
struct is_valid_iterator
: std::bool_constant<std::is_base_of_v<
T, std::remove_cv_t<std::remove_reference_t<decltype(*std::declval<IterType>())>>>> {};

template <typename T, typename IterType>
struct is_valid_iterator<Optional<T>, IterType> : is_valid_iterator<T, IterType> {};

template <typename T, typename IterType>
inline constexpr bool is_valid_iterator_v = is_valid_iterator<T, IterType>::value;

/*!
* \brief Array, container representing a contiguous sequence of ObjectRefs.
*
Expand Down Expand Up @@ -574,54 +593,39 @@ class Array : public ObjectRef {
/*! \return The underlying ArrayNode */
ArrayNode* GetArrayNode() const { return static_cast<ArrayNode*>(data_.get()); }

/*!
* \brief Helper function to apply a map function onto the array.
*
* \param fmap The transformation function T -> U.
*
* \tparam F The type of the mutation function.
*
* \tparam U The type of the returned array, inferred from the
* return type of F. If overridden by the user, must be something
* that is convertible from the return type of F.
*
* \note This function performs copy on write optimization. If
* `fmap` returns an object of type `T`, and all elements of the
* array are mapped to themselves, then the returned array will be
* the same as the original, and reference counts of the elements in
* the array will not be incremented.
*
* \return The transformed array.
*/
template <typename F, typename U = std::invoke_result_t<F, T>>
Array<U> Map(F fmap) const {
return Array<U>(MapHelper(data_, fmap));
}

/*!
* \brief Helper function to apply fmutate to mutate an array.
* \param fmutate The transformation function T -> T.
* \tparam F the type of the mutation function.
* \note This function performs copy on write optimization.
*/
template <typename F>
template <typename F, typename = std::enable_if_t<std::is_same_v<T, std::invoke_result_t<F, T>>>>
void MutateByApply(F fmutate) {
if (data_ == nullptr) {
return;
}
struct StackFrame {
ArrayNode* p;
ObjectRef* itr;
int64_t i;
int64_t size;
};
std::unique_ptr<StackFrame> s = std::make_unique<StackFrame>();
s->p = GetArrayNode();
s->itr = s->p->MutableBegin();
s->i = 0;
s->size = s->p->size_;
if (!data_.unique()) {
// Loop invariant: keeps iterating when
// 1) data is not unique
// 2) no elements are actually mutated yet
for (; s->i < s->size; ++s->i, ++s->itr) {
T new_elem = fmutate(DowncastNoCheck<T>(*s->itr));
// do nothing when there is no mutation
if (new_elem.same_as(*s->itr)) {
continue;
}
// loop invariant breaks when the first real mutation happens
// we copy the elements into a new unique array
ObjectPtr<ArrayNode> copy = ArrayNode::CopyFrom(s->p->capacity_, s->p);
s->itr = copy->MutableBegin() + (s->i++);
*s->itr++ = std::move(new_elem);
data_ = std::move(copy);
// make sure `data_` is unique and break
break;
}
}
// when execution comes to this line, it is guaranteed that either
// 1) i == size
// or 2) data_.unique() is true
for (; s->i < s->size; ++s->i, ++s->itr) {
*s->itr = std::move(fmutate(std::move(DowncastNoCheck<T>(std::move(*s->itr)))));
}
data_ = MapHelper(std::move(data_), fmutate);
}

/*!
Expand Down Expand Up @@ -706,6 +710,118 @@ class Array : public ObjectRef {
}
return static_cast<ArrayNode*>(data_.get());
}

/*! \brief Helper method for mutate/map
*
* A helper function used internally by both `Array::Map` and
* `Array::MutateInPlace`. Given an array of data, apply the
* mapping function to each element, returning the collected array.
* Applies both mutate-in-place and copy-on-write optimizations, if
* possible.
*
* \param data A pointer to the ArrayNode containing input data.
* Passed by value to allow for mutate-in-place optimizations.
*
* \param fmap The mapping function
*
* \tparam F The type of the mutation function.
*
* \tparam U The output type of the mutation function. Inferred
* from the callable type given. Must inherit from ObjectRef.
*
* \return The mapped array. Depending on whether mutate-in-place
* or copy-on-write optimizations were applicable, may be the same
* underlying array as the `data` parameter.
*/
template <typename F, typename U = std::invoke_result_t<F, T>>
static ObjectPtr<Object> MapHelper(ObjectPtr<Object> data, F fmap) {
if (data == nullptr) {
return nullptr;
}

ICHECK(data->IsInstance<ArrayNode>());

constexpr bool is_same_output_type = std::is_same_v<T, U>;

if constexpr (is_same_output_type) {
if (data.unique()) {
// Mutate-in-place path. Only allowed if the output type U is
// the same as type T, we have a mutable this*, and there are
// no other shared copies of the array.
auto arr = static_cast<ArrayNode*>(data.get());
for (auto it = arr->MutableBegin(); it != arr->MutableEnd(); it++) {
T mapped = fmap(DowncastNoCheck<T>(std::move(*it)));
*it = std::move(mapped);
}
return data;
}
}

constexpr bool compatible_types = is_valid_iterator_v<T, U*> || is_valid_iterator_v<U, T*>;

ObjectPtr<ArrayNode> output = nullptr;
auto arr = static_cast<ArrayNode*>(data.get());

auto it = arr->begin();
if constexpr (compatible_types) {
// Copy-on-write path, if the output Array<U> might be
// represented by the same underlying array as the existing
// Array<T>. Typically, this is for functions that map `T` to
// `T`, but can also apply to functions that map `T` to
// `Optional<T>`, or that map `T` to a subclass or superclass of
// `T`.
bool all_identical = true;
for (; it != arr->end(); it++) {
U mapped = fmap(DowncastNoCheck<T>(*it));
if (!mapped.same_as(*it)) {
// At least one mapped element is different than the
// original. Therefore, prepare the output array,
// consisting of any previous elements that had mapped to
// themselves (if any), and the element that didn't map to
// itself.
all_identical = false;
output = ArrayNode::CreateRepeated(arr->size(), U());
output->InitRange(0, arr->begin(), it);
output->SetItem(it - arr->begin(), std::move(mapped));
it++;
break;
}
}
if (all_identical) {
return data;
}
} else {
// Path for incompatible types. The constexpr check for
// compatible types isn't strictly necessary, as the first
// mapped.same_as(*it) would return false, but we might as well
// avoid it altogether.
output = ArrayNode::CreateRepeated(arr->size(), U());
}

// Normal path for incompatible types, or post-copy path for
// copy-on-write instances.
//
// If the types are incompatible, then at this point `output` is
// empty, and `it` points to the first element of the input.
//
// If the types were compatible, then at this point `output`
// contains zero or more elements that mapped to themselves
// followed by the first element that does not map to itself, and
// `it` points to the element just after the first element that
// does not map to itself. Because at least one element has been
// changed, we no longer have the opportunity to avoid a copy, so
// we don't need to check the result.
//
// In both cases, `it` points to the next element to be processed,
// so we can either start or resume the iteration from that point,
// with no further checks on the result.
for (; it != arr->end(); it++) {
U mapped = fmap(DowncastNoCheck<T>(*it));
output->SetItem(it - arr->begin(), std::move(mapped));
}

return output;
}
};

/*!
Expand Down
9 changes: 1 addition & 8 deletions src/ir/type_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,7 @@ Type TypeMutator::VisitType(const Type& t) {
Array<Type> TypeMutator::MutateArray(Array<Type> arr) {
// The array will do copy on write
// If no changes are made, the original array will be returned.
for (size_t i = 0; i < arr.size(); ++i) {
Type ty = arr[i];
Type new_ty = VisitType(ty);
if (!ty.same_as(new_ty)) {
arr.Set(i, new_ty);
}
}
return arr;
return arr.Map([this](const Type& ty) { return VisitType(ty); });
}

Type TypeMutator::VisitType_(const TypeVarNode* op) { return GetRef<TypeVar>(op); }
Expand Down
2 changes: 1 addition & 1 deletion src/te/operation/create_primfunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
// TensorIR will not allow Tensor data structure
if (value->IsInstance<ArrayNode>()) {
const auto array_value = Downcast<Array<ObjectRef>>(value);
annotations.Set(key, MutateArray(array_value, mutate_attr));
annotations.Set(key, array_value.Map(mutate_attr));
} else {
annotations.Set(key, mutate_attr(value));
}
Expand Down
5 changes: 2 additions & 3 deletions src/tir/analysis/device_constraint_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -393,9 +393,8 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator {
}

template <typename T>
Array<T> VisitItems(Array<T> items) {
items.MutateByApply([this](const T& item) { return VisitItem(item.get()); }); // copy-on-write
return items;
Array<T> VisitItems(const Array<T>& items) {
return items.Map([this](T item) -> T { return VisitItem(item.get()); });
}

Stmt VisitStmt_(const BlockNode* block_node) final {
Expand Down
4 changes: 2 additions & 2 deletions src/tir/ir/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,8 @@ Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const
ICHECK(n != nullptr);
arith::Analyzer ana;
begins = SimplifyArray(&ana, begins);
Array<PrimExpr> elem_offset = n->ElemOffset(begins);
elem_offset.MutateByApply([&](const PrimExpr& expr) { return ana.Simplify(expr); });
Array<PrimExpr> elem_offset =
n->ElemOffset(begins).Map([&](const PrimExpr& expr) { return ana.Simplify(expr); });

Array<PrimExpr> strides = n->strides;
if (strides.size() == 0) {
Expand Down
3 changes: 1 addition & 2 deletions src/tir/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -994,8 +994,7 @@ Array<PrimExpr> CommReducerNode::operator()(Array<PrimExpr> a, Array<PrimExpr> b
value_map.Set(lhs[i], a[i]);
value_map.Set(rhs[i], b[i]);
}
auto ret = this->result;
ret.MutateByApply([&value_map](const PrimExpr& e) { return Substitute(e, value_map); });
auto ret = this->result.Map([&value_map](const PrimExpr& e) { return Substitute(e, value_map); });
return ret;
}

Expand Down
14 changes: 7 additions & 7 deletions src/tir/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ PrimExpr ExprMutator::VisitExpr_(const LoadNode* op) {

PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) {
auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); };
Array<PrimExpr> indices = MutateArray(op->indices, fmutate);
Array<PrimExpr> indices = op->indices.Map(fmutate);
if (indices.same_as(op->indices)) {
return GetRef<PrimExpr>(op);
} else {
Expand All @@ -142,7 +142,7 @@ PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) {

PrimExpr ExprMutator::VisitExpr_(const ProducerLoadNode* op) {
auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); };
Array<PrimExpr> indices = MutateArray(op->indices, fmutate);
Array<PrimExpr> indices = op->indices.Map(fmutate);
if (indices.same_as(op->indices)) {
return GetRef<PrimExpr>(op);
} else {
Expand All @@ -162,7 +162,7 @@ PrimExpr ExprMutator::VisitExpr_(const LetNode* op) {

PrimExpr ExprMutator::VisitExpr_(const CallNode* op) {
auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); };
Array<PrimExpr> args = MutateArray(op->args, fmutate);
Array<PrimExpr> args = op->args.Map(fmutate);

if (args.same_as(op->args)) {
return GetRef<PrimExpr>(op);
Expand Down Expand Up @@ -218,11 +218,11 @@ PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) {
return IterVar(Range::FromMinExtent(min, extent), v->var, v->iter_type, v->thread_tag);
}
};
Array<IterVar> axis = MutateArray(op->axis, fitervar);
Array<IterVar> axis = op->axis.Map(fitervar);

auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); };
Array<PrimExpr> source = MutateArray(op->source, fexpr);
Array<PrimExpr> init = MutateArray(op->init, fexpr);
Array<PrimExpr> source = op->source.Map(fexpr);
Array<PrimExpr> init = op->init.Map(fexpr);

PrimExpr condition = this->VisitExpr(op->condition);

Expand Down Expand Up @@ -285,7 +285,7 @@ PrimExpr ExprMutator::VisitExpr_(const BroadcastNode* op) {

PrimExpr ExprMutator::VisitExpr_(const ShuffleNode* op) {
auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); };
auto vectors = MutateArray(op->vectors, fexpr);
auto vectors = op->vectors.Map(fexpr);
if (vectors.same_as(op->vectors)) {
return GetRef<PrimExpr>(op);
} else {
Expand Down
3 changes: 1 addition & 2 deletions src/tir/ir/functor_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ inline void VisitArray(const Array<T>& arr, F fvisit) {

template <typename T, typename F>
inline Array<T> MutateArray(Array<T> arr, F fmutate) {
arr.MutateByApply(fmutate);
return arr;
return arr.Map(fmutate);
}

} // namespace tir
Expand Down
5 changes: 2 additions & 3 deletions src/tir/ir/index_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,8 @@ Array<PrimExpr> IndexMapNode::MapIndices(const Array<PrimExpr>& indices,
analyzer = &local_analyzer;
}

Array<PrimExpr> output = final_indices;
output.MutateByApply(
[&](const PrimExpr& index) { return analyzer->Simplify(Substitute(index, vmap)); });
Array<PrimExpr> output = final_indices.Map(
[&](PrimExpr index) { return analyzer->Simplify(Substitute(std::move(index), vmap)); });

return output;
}
Expand Down
Loading

0 comments on commit 534378b

Please sign in to comment.