-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Containers] Add Array::Map #12692
[Containers] Add Array::Map #12692
Changes from all commits
b079ab6
6559231
95ad711
b38e5b5
0917cc1
2285ea1
c7afe70
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,10 +26,12 @@ | |
|
||
#include <algorithm> | ||
#include <memory> | ||
#include <type_traits> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "./base.h" | ||
#include "./optional.h" | ||
|
||
namespace tvm { | ||
namespace runtime { | ||
|
@@ -248,6 +250,23 @@ class ArrayNode : public Object, public InplaceArrayBase<ArrayNode, ObjectRef> { | |
friend ObjectPtr<ArrayNode> make_object<>(); | ||
}; | ||
|
||
/*! \brief Helper struct for type-checking
 *
 * is_valid_iterator<T,IterType>::value is true if dereferencing an
 * IterType yields, after stripping references and cv-qualifiers, either
 * T itself or a class derived from T (i.e. something that can be stored
 * in an Array<T>), and false otherwise.
 *
 * \tparam T The element type of the target Array.
 *
 * \tparam IterType The iterator (or pointer) type being checked.  It
 * is only inspected in an unevaluated context through
 * `*std::declval<IterType>()`.
 */
template <typename T, typename IterType>
struct is_valid_iterator
    : std::bool_constant<std::is_base_of_v<
          T, std::remove_cv_t<std::remove_reference_t<decltype(*std::declval<IterType>())>>>> {};
|
||
template <typename T, typename IterType> | ||
struct is_valid_iterator<Optional<T>, IterType> : is_valid_iterator<T, IterType> {}; | ||
|
||
template <typename T, typename IterType> | ||
inline constexpr bool is_valid_iterator_v = is_valid_iterator<T, IterType>::value; | ||
|
||
/*! | ||
* \brief Array, container representing a contiguous sequence of ObjectRefs. | ||
* | ||
|
@@ -574,54 +593,39 @@ class Array : public ObjectRef { | |
/*! \return The underlying ArrayNode */ | ||
ArrayNode* GetArrayNode() const { return static_cast<ArrayNode*>(data_.get()); } | ||
|
||
/*! | ||
* \brief Helper function to apply a map function onto the array. | ||
* | ||
* \param fmap The transformation function T -> U. | ||
* | ||
* \tparam F The type of the mutation function. | ||
* | ||
* \tparam U The type of the returned array, inferred from the | ||
* return type of F. If overridden by the user, must be something | ||
* that is convertible from the return type of F. | ||
* | ||
* \note This function performs copy on write optimization. If | ||
* `fmap` returns an object of type `T`, and all elements of the | ||
* array are mapped to themselves, then the returned array will be | ||
* the same as the original, and reference counts of the elements in | ||
* the array will not be incremented. | ||
* | ||
* \return The transformed array. | ||
*/ | ||
template <typename F, typename U = std::invoke_result_t<F, T>> | ||
Array<U> Map(F fmap) const { | ||
return Array<U>(MapHelper(data_, fmap)); | ||
} | ||
|
||
/*! | ||
* \brief Helper function to apply fmutate to mutate an array. | ||
* \param fmutate The transformation function T -> T. | ||
* \tparam F the type of the mutation function. | ||
* \note This function performs copy on write optimization. | ||
*/ | ||
template <typename F> | ||
template <typename F, typename = std::enable_if_t<std::is_same_v<T, std::invoke_result_t<F, T>>>> | ||
void MutateByApply(F fmutate) { | ||
if (data_ == nullptr) { | ||
return; | ||
} | ||
struct StackFrame { | ||
ArrayNode* p; | ||
ObjectRef* itr; | ||
int64_t i; | ||
int64_t size; | ||
}; | ||
std::unique_ptr<StackFrame> s = std::make_unique<StackFrame>(); | ||
s->p = GetArrayNode(); | ||
s->itr = s->p->MutableBegin(); | ||
s->i = 0; | ||
s->size = s->p->size_; | ||
if (!data_.unique()) { | ||
// Loop invariant: keeps iterating when | ||
// 1) data is not unique | ||
// 2) no elements are actually mutated yet | ||
for (; s->i < s->size; ++s->i, ++s->itr) { | ||
T new_elem = fmutate(DowncastNoCheck<T>(*s->itr)); | ||
// do nothing when there is no mutation | ||
if (new_elem.same_as(*s->itr)) { | ||
continue; | ||
} | ||
// loop invariant breaks when the first real mutation happens | ||
// we copy the elements into a new unique array | ||
ObjectPtr<ArrayNode> copy = ArrayNode::CopyFrom(s->p->capacity_, s->p); | ||
s->itr = copy->MutableBegin() + (s->i++); | ||
*s->itr++ = std::move(new_elem); | ||
data_ = std::move(copy); | ||
// make sure `data_` is unique and break | ||
break; | ||
} | ||
} | ||
// when execution comes to this line, it is guaranteed that either | ||
// 1) i == size | ||
// or 2) data_.unique() is true | ||
for (; s->i < s->size; ++s->i, ++s->itr) { | ||
*s->itr = std::move(fmutate(std::move(DowncastNoCheck<T>(std::move(*s->itr))))); | ||
} | ||
data_ = MapHelper(std::move(data_), fmutate); | ||
} | ||
|
||
/*! | ||
|
@@ -706,6 +710,118 @@ class Array : public ObjectRef { | |
} | ||
return static_cast<ArrayNode*>(data_.get()); | ||
} | ||
|
||
/*! \brief Helper method for mutate/map | ||
* | ||
* A helper function used internally by both `Array::Map` and | ||
* `Array::MutateInPlace`. Given an array of data, apply the | ||
* mapping function to each element, returning the collected array. | ||
* Applies both mutate-in-place and copy-on-write optimizations, if | ||
* possible. | ||
* | ||
* \param data A pointer to the ArrayNode containing input data. | ||
* Passed by value to allow for mutate-in-place optimizations. | ||
* | ||
* \param fmap The mapping function | ||
* | ||
* \tparam F The type of the mutation function. | ||
* | ||
* \tparam U The output type of the mutation function. Inferred | ||
* from the callable type given. Must inherit from ObjectRef. | ||
* | ||
* \return The mapped array. Depending on whether mutate-in-place | ||
* or copy-on-write optimizations were applicable, may be the same | ||
* underlying array as the `data` parameter. | ||
*/ | ||
template <typename F, typename U = std::invoke_result_t<F, T>> | ||
static ObjectPtr<Object> MapHelper(ObjectPtr<Object> data, F fmap) { | ||
if (data == nullptr) { | ||
return nullptr; | ||
} | ||
|
||
ICHECK(data->IsInstance<ArrayNode>()); | ||
|
||
constexpr bool is_same_output_type = std::is_same_v<T, U>; | ||
|
||
if constexpr (is_same_output_type) { | ||
if (data.unique()) { | ||
// Mutate-in-place path. Only allowed if the output type U is | ||
// the same as type T, we have a mutable this*, and there are | ||
// no other shared copies of the array. | ||
auto arr = static_cast<ArrayNode*>(data.get()); | ||
for (auto it = arr->MutableBegin(); it != arr->MutableEnd(); it++) { | ||
T mapped = fmap(DowncastNoCheck<T>(std::move(*it))); | ||
*it = std::move(mapped); | ||
} | ||
return data; | ||
} | ||
} | ||
|
||
constexpr bool compatible_types = is_valid_iterator_v<T, U*> || is_valid_iterator_v<U, T*>; | ||
|
||
ObjectPtr<ArrayNode> output = nullptr; | ||
auto arr = static_cast<ArrayNode*>(data.get()); | ||
|
||
auto it = arr->begin(); | ||
if constexpr (compatible_types) { | ||
// Copy-on-write path, if the output Array<U> might be | ||
// represented by the same underlying array as the existing | ||
// Array<T>. Typically, this is for functions that map `T` to | ||
// `T`, but can also apply to functions that map `T` to | ||
// `Optional<T>`, or that map `T` to a subclass or superclass of | ||
// `T`. | ||
bool all_identical = true; | ||
for (; it != arr->end(); it++) { | ||
U mapped = fmap(DowncastNoCheck<T>(*it)); | ||
if (!mapped.same_as(*it)) { | ||
// At least one mapped element is different than the | ||
// original. Therefore, prepare the output array, | ||
// consisting of any previous elements that had mapped to | ||
// themselves (if any), and the element that didn't map to | ||
// itself. | ||
all_identical = false; | ||
output = ArrayNode::CreateRepeated(arr->size(), U()); | ||
output->InitRange(0, arr->begin(), it); | ||
output->SetItem(it - arr->begin(), std::move(mapped)); | ||
it++; | ||
break; | ||
} | ||
} | ||
if (all_identical) { | ||
return data; | ||
} | ||
} else { | ||
// Path for incompatible types. The constexpr check for | ||
// compatible types isn't strictly necessary, as the first | ||
// mapped.same_as(*it) would return false, but we might as well | ||
// avoid it altogether. | ||
output = ArrayNode::CreateRepeated(arr->size(), U()); | ||
} | ||
|
||
// Normal path for incompatible types, or post-copy path for | ||
// copy-on-write instances. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What will be left over on the copy-on-write instance? Will there be some items that are incompatible? How are those guaranteed to be at the end? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
If we have compatible types, and we've reached this point, we've found at least one element for which the
It's entirely possible, either at compile-time or at runtime. For example, I could have an If a type is incompatible at runtime, then it will also fail the
Incompatible items may occur at any point in the mapped output, even at the very first iteration. In that case, the commands executed in the conditional on // Same as the else branch on `compatible_types`
output = ArrayNode::CreateRepeated(arr->size(), U());
// For the first iteration, it is `arr->begin()`, so this would be an
// empty range [begin, begin), nothing is initialized, and this
// statement has no effect.
output->InitRange(0, arr->begin(), it);
// The newly mapped item is stored to the first location of the output.
output->SetItem(it - arr->begin(), std::move(mapped));
// The loop increment that would have happened
it++;
// `it` now points to the second element of the input, and we have one
// mapped element in the output. We're now ready to start the second
// loop, just at the second iteration instead of the first. Essentially, we only need to check for identical return values up until we find a single non-identical element, at which point we know that we can't avoid the copy anyways. But once we reach the first non-identical value, we don't need to repeat the function calls up to that point, because we know that everything is either identical (and can therefore be copied from the input) or is non-identical (is which case it is the first such non-identical value). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for this explanation! Maybe it would be helpful to others as well to summarize this in the comment block on 796... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No problem, and updated! |
||
// | ||
// If the types are incompatible, then at this point `output` is | ||
// empty, and `it` points to the first element of the input. | ||
// | ||
// If the types were compatible, then at this point `output` | ||
// contains zero or more elements that mapped to themselves | ||
// followed by the first element that does not map to itself, and | ||
// `it` points to the element just after the first element that | ||
// does not map to itself. Because at least one element has been | ||
// changed, we no longer have the opportunity to avoid a copy, so | ||
// we don't need to check the result. | ||
// | ||
// In both cases, `it` points to the next element to be processed, | ||
// so we can either start or resume the iteration from that point, | ||
// with no further checks on the result. | ||
for (; it != arr->end(); it++) { | ||
U mapped = fmap(DowncastNoCheck<T>(*it)); | ||
output->SetItem(it - arr->begin(), std::move(mapped)); | ||
} | ||
|
||
return output; | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. General comment - Can you add a unit test to exercise the edge cases in MapHelper? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Certainly, and thank you for pointing that out! There are some existing tests in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tests added for each of the compatible types, to validate that copies are avoided, and to ensure correct fail-through behavior when a copy is required. A double thanks for requesting it, as it also caught a type conversion error that I had missed. |
||
}; | ||
|
||
/*! | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was curious if we could unify and migrate `MutateByApply` into `Map`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Possibly, and that would allow for avoiding copies in a few additional cases (e.g. map from `T` to `Optional<T>`, or to a superclass of `T`) that aren't currently handled. I'll take a quick stab at it and see if I can unify the two.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the suggestion, and it ended up being much cleaner that way. Both `Map` and `MutateByApply` are now implemented in terms of the same underlying helper function. The helper function applies both the mutate-in-place and copy-on-write optimizations, with `if constexpr` type checks to avoid attempting an optimization where it wouldn't be possible.