Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Containers] Add Array::Map #12692

Merged
merged 7 commits into from
Sep 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 157 additions & 41 deletions include/tvm/runtime/container/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@

#include <algorithm>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>

#include "./base.h"
#include "./optional.h"

namespace tvm {
namespace runtime {
Expand Down Expand Up @@ -248,6 +250,23 @@ class ArrayNode : public Object, public InplaceArrayBase<ArrayNode, ObjectRef> {
friend ObjectPtr<ArrayNode> make_object<>();
};

/*!
 * \brief Compile-time trait describing what an iterator yields.
 *
 * `is_valid_iterator<T, IterType>::value` is true when the type
 * obtained by dereferencing an `IterType` (with references and
 * cv-qualifiers stripped) is the class `T` itself or a class derived
 * from `T`, and false otherwise. This is the condition used to decide
 * whether the dereferenced value may be stored in an `Array<T>`.
 *
 * \tparam T The element type to check against.
 * \tparam IterType The iterator (or pointer) type being inspected.
 */
template <typename T, typename IterType>
struct is_valid_iterator
    : std::is_base_of<T, std::remove_cv_t<std::remove_reference_t<
                             decltype(*std::declval<IterType>())>>>::type {};

/*!
 * \brief Specialization of is_valid_iterator for Optional<T>.
 *
 * The check is delegated to the wrapped type: an iterator is accepted
 * for `Optional<T>` exactly when it is accepted for `T`.
 */
template <typename T, typename IterType>
struct is_valid_iterator<Optional<T>, IterType> : is_valid_iterator<T, IterType> {};

/*!
 * \brief Convenience variable template, shorthand for
 * `is_valid_iterator<T, IterType>::value`.
 */
template <typename T, typename IterType>
inline constexpr bool is_valid_iterator_v = is_valid_iterator<T, IterType>::value;

/*!
* \brief Array, container representing a contiguous sequence of ObjectRefs.
*
Expand Down Expand Up @@ -574,54 +593,39 @@ class Array : public ObjectRef {
/*! \return The underlying ArrayNode */
ArrayNode* GetArrayNode() const { return static_cast<ArrayNode*>(data_.get()); }

/*!
* \brief Helper function to apply a map function onto the array.
*
* \param fmap The transformation function T -> U.
*
* \tparam F The type of the mutation function.
*
* \tparam U The type of the returned array, inferred from the
* return type of F. If overridden by the user, must be something
* that is convertible from the return type of F.
*
* \note This function performs copy on write optimization. If
* `fmap` returns an object of type `T`, and all elements of the
* array are mapped to themselves, then the returned array will be
* the same as the original, and reference counts of the elements in
* the array will not be incremented.
*
* \return The transformed array.
*/
template <typename F, typename U = std::invoke_result_t<F, T>>
Array<U> Map(F fmap) const {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was curious if we could unify and migrate MutateByApply into Map?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly, and that would allow for avoiding copies in a few additional cases (e.g. map from T to Optional<T>, or to a superclass of T) that aren't currently handled. I'll take a quick stab at it and see if I can unify the two.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the suggestion, and it ended up being much cleaner that way. Both `Map` and `MutateByApply` are now implemented in terms of the same underlying helper function. The helper function applies both the mutate-in-place and copy-on-write optimizations, with `if constexpr` type checks to avoid attempting the optimization if they wouldn't be possible.

return Array<U>(MapHelper(data_, fmap));
}

/*!
* \brief Helper function to apply fmutate to mutate an array.
* \param fmutate The transformation function T -> T.
* \tparam F the type of the mutation function.
* \note This function performs copy on write optimization.
*/
template <typename F>
template <typename F, typename = std::enable_if_t<std::is_same_v<T, std::invoke_result_t<F, T>>>>
void MutateByApply(F fmutate) {
if (data_ == nullptr) {
return;
}
struct StackFrame {
ArrayNode* p;
ObjectRef* itr;
int64_t i;
int64_t size;
};
std::unique_ptr<StackFrame> s = std::make_unique<StackFrame>();
s->p = GetArrayNode();
s->itr = s->p->MutableBegin();
s->i = 0;
s->size = s->p->size_;
if (!data_.unique()) {
// Loop invariant: keeps iterating when
// 1) data is not unique
// 2) no elements are actually mutated yet
for (; s->i < s->size; ++s->i, ++s->itr) {
T new_elem = fmutate(DowncastNoCheck<T>(*s->itr));
// do nothing when there is no mutation
if (new_elem.same_as(*s->itr)) {
continue;
}
// loop invariant breaks when the first real mutation happens
// we copy the elements into a new unique array
ObjectPtr<ArrayNode> copy = ArrayNode::CopyFrom(s->p->capacity_, s->p);
s->itr = copy->MutableBegin() + (s->i++);
*s->itr++ = std::move(new_elem);
data_ = std::move(copy);
// make sure `data_` is unique and break
break;
}
}
// when execution comes to this line, it is guaranteed that either
// 1) i == size
// or 2) data_.unique() is true
for (; s->i < s->size; ++s->i, ++s->itr) {
*s->itr = std::move(fmutate(std::move(DowncastNoCheck<T>(std::move(*s->itr)))));
}
data_ = MapHelper(std::move(data_), fmutate);
}

/*!
Expand Down Expand Up @@ -706,6 +710,118 @@ class Array : public ObjectRef {
}
return static_cast<ArrayNode*>(data_.get());
}

/*! \brief Helper method for mutate/map
*
* A helper function used internally by both `Array::Map` and
* `Array::MutateByApply`. Given an array of data, apply the
* mapping function to each element, returning the collected array.
* Applies both mutate-in-place and copy-on-write optimizations, if
* possible.
*
* \param data A pointer to the ArrayNode containing input data.
* Passed by value to allow for mutate-in-place optimizations.
*
* \param fmap The mapping function
*
* \tparam F The type of the mutation function.
*
* \tparam U The output type of the mutation function. Inferred
* from the callable type given. Must inherit from ObjectRef.
*
* \return The mapped array. Depending on whether mutate-in-place
* or copy-on-write optimizations were applicable, may be the same
* underlying array as the `data` parameter.
*/
template <typename F, typename U = std::invoke_result_t<F, T>>
static ObjectPtr<Object> MapHelper(ObjectPtr<Object> data, F fmap) {
if (data == nullptr) {
return nullptr;
}

ICHECK(data->IsInstance<ArrayNode>());

constexpr bool is_same_output_type = std::is_same_v<T, U>;

if constexpr (is_same_output_type) {
if (data.unique()) {
// Mutate-in-place path. Only allowed if the output type U is
// the same as type T, we have a mutable this*, and there are
// no other shared copies of the array.
auto arr = static_cast<ArrayNode*>(data.get());
for (auto it = arr->MutableBegin(); it != arr->MutableEnd(); it++) {
T mapped = fmap(DowncastNoCheck<T>(std::move(*it)));
*it = std::move(mapped);
}
return data;
}
}

constexpr bool compatible_types = is_valid_iterator_v<T, U*> || is_valid_iterator_v<U, T*>;

ObjectPtr<ArrayNode> output = nullptr;
auto arr = static_cast<ArrayNode*>(data.get());

auto it = arr->begin();
if constexpr (compatible_types) {
// Copy-on-write path, if the output Array<U> might be
// represented by the same underlying array as the existing
// Array<T>. Typically, this is for functions that map `T` to
// `T`, but can also apply to functions that map `T` to
// `Optional<T>`, or that map `T` to a subclass or superclass of
// `T`.
bool all_identical = true;
for (; it != arr->end(); it++) {
U mapped = fmap(DowncastNoCheck<T>(*it));
if (!mapped.same_as(*it)) {
// At least one mapped element is different than the
// original. Therefore, prepare the output array,
// consisting of any previous elements that had mapped to
// themselves (if any), and the element that didn't map to
// itself.
all_identical = false;
output = ArrayNode::CreateRepeated(arr->size(), U());
output->InitRange(0, arr->begin(), it);
output->SetItem(it - arr->begin(), std::move(mapped));
it++;
break;
}
}
if (all_identical) {
return data;
}
} else {
// Path for incompatible types. The constexpr check for
// compatible types isn't strictly necessary, as the first
// mapped.same_as(*it) would return false, but we might as well
// avoid it altogether.
output = ArrayNode::CreateRepeated(arr->size(), U());
}

// Normal path for incompatible types, or post-copy path for
// copy-on-write instances.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What will be left over on the copy-on-write instance? Will there be some items that are incompatible? How are those guaranteed to be at the end?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What will be left over on the copy-on-write instance?

If we have compatible types, and we've reached this point, we've found at least one element for which the `mapped.same_as(*it)` check on line 776 has failed. In that case, `output` will contain everything in the range `[arr->begin(), it)`. That is, `output` contains all elements that are identical, and the first non-identical element. `it` will point to the next element that should be transformed, and so the next loop over `it` can continue where the first loop left off.

Will there be some items that are incompatible?

It's entirely possible, either at compile-time or at runtime. For example, I could have an Array<PrimExpr> buffer_shape and map it to allowed ranges buffer_shape.Map([](PrimExpr expr) { return Range::FromMinExtent(0, expr);});, which would be incompatible and identified as such at compile-time. In that case, the if constexpr could identify that they cannot be represented by the same underlying array, and can skip the attempts to do so altogether.

If a type is incompatible at runtime, then it will also fail the mapped.same_as(*it) check on line 776. So if I have an Array<Var> being mapped to Array<PrimExpr> with var_array.Map([&](Var var) { return var.same_as(to_replace) ? replace_with : var;});, it may or may not be compatible, depending on whether to_replace shows up in the array.

How are those guaranteed to be at the end?

Incompatible items may occur at any point in the mapped output, even at the very first iteration. In that case, the commands executed in the conditional on !mapped.same_as(*it) are the same as would be executed up through the first iteration of the mapping loop.

// Same as the else branch on `compatible_types`
output = ArrayNode::CreateRepeated(arr->size(), U());

// For the first iteration, it is `arr->begin()`, so this would be an
// empty range [begin, begin), nothing is initialized, and this
// statement has no effect.
output->InitRange(0, arr->begin(), it);

// The newly mapped item is stored to the first location of the output.
output->SetItem(it - arr->begin(), std::move(mapped));

// The loop increment that would have happened
it++;

// `it` now points to the second element of the input, and we have one
// mapped element in the output.  We're now ready to start the second
// loop, just at the second iteration instead of the first.

Essentially, we only need to check for identical return values up until we find a single non-identical element, at which point we know that we can't avoid the copy anyway. But once we reach the first non-identical value, we don't need to repeat the function calls up to that point, because we know that everything is either identical (and can therefore be copied from the input) or is non-identical (in which case it is the first such non-identical value).

Copy link
Collaborator

@janetsc janetsc Sep 19, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this explanation! Maybe it would be helpful to others as well to summarize this in the comment block on 796...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No problem, and updated!

//
// If the types are incompatible, then at this point `output` is
// empty, and `it` points to the first element of the input.
//
// If the types were compatible, then at this point `output`
// contains zero or more elements that mapped to themselves
// followed by the first element that does not map to itself, and
// `it` points to the element just after the first element that
// does not map to itself. Because at least one element has been
// changed, we no longer have the opportunity to avoid a copy, so
// we don't need to check the result.
//
// In both cases, `it` points to the next element to be processed,
// so we can either start or resume the iteration from that point,
// with no further checks on the result.
for (; it != arr->end(); it++) {
U mapped = fmap(DowncastNoCheck<T>(*it));
output->SetItem(it - arr->begin(), std::move(mapped));
}

return output;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

General comment - Can you add a unit test to exercise the edge cases in MapHelper?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Certainly, and thank you for pointing that out! There are some existing tests in container_test.cc, along with a large amount of usage when lowering TIR, but no tests that would specifically point to these edge cases.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests added for each of the compatible types, to validate that copies are avoided, and to ensure correct fail-through behavior when a copy is required. A double thanks for requesting it, as it also caught a type conversion error that I had missed.

};

/*!
Expand Down
9 changes: 1 addition & 8 deletions src/ir/type_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,7 @@ Type TypeMutator::VisitType(const Type& t) {
Array<Type> TypeMutator::MutateArray(Array<Type> arr) {
// The array will do copy on write
// If no changes are made, the original array will be returned.
for (size_t i = 0; i < arr.size(); ++i) {
Type ty = arr[i];
Type new_ty = VisitType(ty);
if (!ty.same_as(new_ty)) {
arr.Set(i, new_ty);
}
}
return arr;
return arr.Map([this](const Type& ty) { return VisitType(ty); });
}

Type TypeMutator::VisitType_(const TypeVarNode* op) { return GetRef<TypeVar>(op); }
Expand Down
2 changes: 1 addition & 1 deletion src/te/operation/create_primfunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
// TensorIR will not allow Tensor data structure
if (value->IsInstance<ArrayNode>()) {
const auto array_value = Downcast<Array<ObjectRef>>(value);
annotations.Set(key, MutateArray(array_value, mutate_attr));
annotations.Set(key, array_value.Map(mutate_attr));
} else {
annotations.Set(key, mutate_attr(value));
}
Expand Down
5 changes: 2 additions & 3 deletions src/tir/analysis/device_constraint_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -393,9 +393,8 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator {
}

template <typename T>
Array<T> VisitItems(Array<T> items) {
items.MutateByApply([this](const T& item) { return VisitItem(item.get()); }); // copy-on-write
return items;
Array<T> VisitItems(const Array<T>& items) {
return items.Map([this](T item) -> T { return VisitItem(item.get()); });
}

Stmt VisitStmt_(const BlockNode* block_node) final {
Expand Down
4 changes: 2 additions & 2 deletions src/tir/ir/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,8 @@ Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const
ICHECK(n != nullptr);
arith::Analyzer ana;
begins = SimplifyArray(&ana, begins);
Array<PrimExpr> elem_offset = n->ElemOffset(begins);
elem_offset.MutateByApply([&](const PrimExpr& expr) { return ana.Simplify(expr); });
Array<PrimExpr> elem_offset =
n->ElemOffset(begins).Map([&](const PrimExpr& expr) { return ana.Simplify(expr); });

Array<PrimExpr> strides = n->strides;
if (strides.size() == 0) {
Expand Down
3 changes: 1 addition & 2 deletions src/tir/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -994,8 +994,7 @@ Array<PrimExpr> CommReducerNode::operator()(Array<PrimExpr> a, Array<PrimExpr> b
value_map.Set(lhs[i], a[i]);
value_map.Set(rhs[i], b[i]);
}
auto ret = this->result;
ret.MutateByApply([&value_map](const PrimExpr& e) { return Substitute(e, value_map); });
auto ret = this->result.Map([&value_map](const PrimExpr& e) { return Substitute(e, value_map); });
return ret;
}

Expand Down
14 changes: 7 additions & 7 deletions src/tir/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ PrimExpr ExprMutator::VisitExpr_(const LoadNode* op) {

PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) {
auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); };
Array<PrimExpr> indices = MutateArray(op->indices, fmutate);
Array<PrimExpr> indices = op->indices.Map(fmutate);
if (indices.same_as(op->indices)) {
return GetRef<PrimExpr>(op);
} else {
Expand All @@ -142,7 +142,7 @@ PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) {

PrimExpr ExprMutator::VisitExpr_(const ProducerLoadNode* op) {
auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); };
Array<PrimExpr> indices = MutateArray(op->indices, fmutate);
Array<PrimExpr> indices = op->indices.Map(fmutate);
if (indices.same_as(op->indices)) {
return GetRef<PrimExpr>(op);
} else {
Expand All @@ -162,7 +162,7 @@ PrimExpr ExprMutator::VisitExpr_(const LetNode* op) {

PrimExpr ExprMutator::VisitExpr_(const CallNode* op) {
auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); };
Array<PrimExpr> args = MutateArray(op->args, fmutate);
Array<PrimExpr> args = op->args.Map(fmutate);

if (args.same_as(op->args)) {
return GetRef<PrimExpr>(op);
Expand Down Expand Up @@ -218,11 +218,11 @@ PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) {
return IterVar(Range::FromMinExtent(min, extent), v->var, v->iter_type, v->thread_tag);
}
};
Array<IterVar> axis = MutateArray(op->axis, fitervar);
Array<IterVar> axis = op->axis.Map(fitervar);

auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); };
Array<PrimExpr> source = MutateArray(op->source, fexpr);
Array<PrimExpr> init = MutateArray(op->init, fexpr);
Array<PrimExpr> source = op->source.Map(fexpr);
Array<PrimExpr> init = op->init.Map(fexpr);

PrimExpr condition = this->VisitExpr(op->condition);

Expand Down Expand Up @@ -285,7 +285,7 @@ PrimExpr ExprMutator::VisitExpr_(const BroadcastNode* op) {

PrimExpr ExprMutator::VisitExpr_(const ShuffleNode* op) {
auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); };
auto vectors = MutateArray(op->vectors, fexpr);
auto vectors = op->vectors.Map(fexpr);
if (vectors.same_as(op->vectors)) {
return GetRef<PrimExpr>(op);
} else {
Expand Down
3 changes: 1 addition & 2 deletions src/tir/ir/functor_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ inline void VisitArray(const Array<T>& arr, F fvisit) {

template <typename T, typename F>
inline Array<T> MutateArray(Array<T> arr, F fmutate) {
arr.MutateByApply(fmutate);
return arr;
return arr.Map(fmutate);
}

} // namespace tir
Expand Down
5 changes: 2 additions & 3 deletions src/tir/ir/index_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,8 @@ Array<PrimExpr> IndexMapNode::MapIndices(const Array<PrimExpr>& indices,
analyzer = &local_analyzer;
}

Array<PrimExpr> output = final_indices;
output.MutateByApply(
[&](const PrimExpr& index) { return analyzer->Simplify(Substitute(index, vmap)); });
Array<PrimExpr> output = final_indices.Map(
[&](PrimExpr index) { return analyzer->Simplify(Substitute(std::move(index), vmap)); });

return output;
}
Expand Down
Loading