-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ARITH] Introduce iterator (quasi)affine map detection. #6667
Merged
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,277 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file tvm/arith/iter_affine_map.h | ||
* \brief Iterator quasi-affine mapping patterns. | ||
* | ||
* This file defines a collection of mapping patterns | ||
* maps a collection of independent iterators to another | ||
* collection of independent iterators. | ||
* | ||
* There are two main kinds of mapping patterns: | ||
* | ||
* - Fuse: fuse a collection of iterators into a single one | ||
* | ||
* domain(x0) = [0, 4), domain(x1) = [0, 3), domain(x2) = [0, 2) | ||
* fuse(x0, x1, x2): y = x2 * 12 + x1 * 4 + x0 | ||
* domain(y) = [0, 24) | ||
* | ||
* - Split: split an iterator into multiple ones | ||
* | ||
* domain(x) = [0, 24) | ||
* split(x, 3, 12): [y0, y1, y2] = [x % 3, (x % 12) / 3, x / 12] | ||
* domain(y0) = [0, 3), domain(y1) = [0, 4), domain(y2) = [0, 2) | ||
* | ||
* We use the name "(quasi)affine" to be consistent with | ||
* the terminology used in the polyhedral compilation. | ||
* Notably, fuse is an affine transformation, | ||
* while split corresponds to additional floordiv/mod operations | ||
* that can appear in quasi-affine transformations. | ||
*/ | ||
#ifndef TVM_ARITH_ITER_AFFINE_MAP_H_ | ||
#define TVM_ARITH_ITER_AFFINE_MAP_H_ | ||
|
||
#include <tvm/ir/expr.h> | ||
|
||
namespace tvm { | ||
namespace arith { | ||
|
||
/*! | ||
* \brief Base class of all iter map expressions. | ||
* | ||
* An IterMapExpr is a special expression to store | ||
* the result of IterMapDetection. | ||
* It should not appear in a legal TIR PrimFunc. | ||
*/ | ||
class IterMapExprNode : public PrimExprNode { | ||
public: | ||
// overrides | ||
void VisitAttrs(tvm::AttrVisitor* v) {} | ||
|
||
static constexpr const char* _type_key = "arith.IterMapExpr"; | ||
static constexpr const uint32_t _type_child_slots = 3; | ||
TVM_DECLARE_BASE_OBJECT_INFO(IterMapExprNode, PrimExprNode); | ||
}; | ||
|
||
/*! | ||
* \brief Managed reference to IterMapExprNode. | ||
* \sa IterMapExprNode | ||
*/ | ||
class IterMapExpr : public PrimExpr { | ||
public: | ||
TVM_DEFINE_OBJECT_REF_METHODS(IterMapExpr, PrimExpr, IterMapExprNode); | ||
}; | ||
|
||
/*! | ||
* \brief Mark the source as an iterator in [0, extent). | ||
* | ||
* IterMark is used to mark source expression as a valid | ||
* iterator to make future analysis easy. | ||
*/ | ||
class IterMarkNode : public Object { | ||
public: | ||
/*! | ||
* \brief The source expression, can either be | ||
* a IterSumExpr or a Var. | ||
*/ | ||
PrimExpr source; | ||
/*! | ||
* \brief The extent of the iteration. | ||
*/ | ||
PrimExpr extent; | ||
|
||
// overrides | ||
void VisitAttrs(tvm::AttrVisitor* v) { | ||
v->Visit("source", &source); | ||
v->Visit("extent", &extent); | ||
} | ||
|
||
bool SEqualReduce(const IterMarkNode* other, SEqualReducer equal) const { | ||
equal->MarkGraphNode(); | ||
return equal(source, other->source) && equal(extent, other->extent); | ||
} | ||
|
||
void SHashReduce(SHashReducer hash_reduce) const { | ||
hash_reduce->MarkGraphNode(); | ||
hash_reduce(source); | ||
hash_reduce(extent); | ||
} | ||
|
||
static constexpr const bool _type_has_method_sequal_reduce = true; | ||
static constexpr const bool _type_has_method_shash_reduce = true; | ||
static constexpr const char* _type_key = "arith.IterMark"; | ||
TVM_DECLARE_FINAL_OBJECT_INFO(IterMarkNode, Object); | ||
}; | ||
|
||
/*! | ||
* \brief Managed reference to IterMarkExprNode. | ||
* \sa IterMarkExprNode | ||
*/ | ||
class IterMark : public ObjectRef { | ||
public: | ||
/*! | ||
* \brief constructor. | ||
* \param source The source expression. | ||
* \param extent The extent of the iterator. | ||
*/ | ||
TVM_DLL IterMark(PrimExpr source, PrimExpr extent); | ||
|
||
TVM_DEFINE_OBJECT_REF_METHODS(IterMark, ObjectRef, IterMarkNode); | ||
}; | ||
|
||
/*! | ||
* \brief Split of an iterator. | ||
* | ||
* result = floormod(floordiv(source, lower_factor), extent) * scale | ||
*/ | ||
class IterSplitExprNode : public IterMapExprNode { | ||
public: | ||
/*! \brief The source marked iterator. */ | ||
IterMark source; | ||
/*! \brief The lower factor to split the source. */ | ||
PrimExpr lower_factor; | ||
/*! \brief The extent of the split. */ | ||
PrimExpr extent; | ||
/*! \brief Additional scale. */ | ||
PrimExpr scale; | ||
|
||
// overrides | ||
void VisitAttrs(tvm::AttrVisitor* v) { | ||
v->Visit("source", &source); | ||
v->Visit("lower_factor", &lower_factor); | ||
v->Visit("extent", &extent); | ||
v->Visit("scale", &scale); | ||
} | ||
|
||
bool SEqualReduce(const IterSplitExprNode* other, SEqualReducer equal) const { | ||
return equal(source, other->source) && equal(lower_factor, other->lower_factor) && | ||
equal(extent, other->extent) && equal(scale, other->scale); | ||
} | ||
|
||
void SHashReduce(SHashReducer hash_reduce) const { | ||
hash_reduce(source); | ||
hash_reduce(lower_factor); | ||
hash_reduce(extent); | ||
hash_reduce(scale); | ||
} | ||
|
||
static constexpr const char* _type_key = "arith.IterSplitExpr"; | ||
TVM_DECLARE_FINAL_OBJECT_INFO(IterSplitExprNode, IterMapExprNode); | ||
}; | ||
|
||
/*! | ||
* \brief Managed reference to IterSplitExprNode. | ||
* \sa IterSplitExprNode | ||
*/ | ||
class IterSplitExpr : public IterMapExpr { | ||
public: | ||
/*! | ||
* \brief constructor from just source. | ||
* \param source The source expression. | ||
*/ | ||
TVM_DLL explicit IterSplitExpr(IterMark source); | ||
/*! | ||
* \brief constructor | ||
* \param source The source expression. | ||
* \param lower_factor The lower factor to split the source. | ||
* \param extent The extent of the split. | ||
* \param scale The additional scaling factor. | ||
*/ | ||
TVM_DLL explicit IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr extent, | ||
PrimExpr scale); | ||
|
||
TVM_DEFINE_OBJECT_REF_METHODS(IterSplitExpr, IterMapExpr, IterSplitExprNode); | ||
TVM_DEFINE_OBJECT_REF_COW_METHOD(IterSplitExprNode); | ||
}; | ||
|
||
/*! | ||
* \brief Fuse multiple iterators by summing them with scaling. | ||
* | ||
* result = sum(args) + base | ||
*/ | ||
class IterSumExprNode : public IterMapExprNode { | ||
public: | ||
/*! \brief The args to the sum. */ | ||
Array<IterSplitExpr> args; | ||
/*! \brief The base offset. */ | ||
PrimExpr base; | ||
|
||
// overrides | ||
void VisitAttrs(tvm::AttrVisitor* v) { | ||
v->Visit("args", &args); | ||
v->Visit("base", &base); | ||
} | ||
|
||
bool SEqualReduce(const IterSumExprNode* other, SEqualReducer equal) const { | ||
return equal(args, other->args) && equal(base, other->base); | ||
} | ||
|
||
void SHashReduce(SHashReducer hash_reduce) const { | ||
hash_reduce(args); | ||
hash_reduce(base); | ||
} | ||
|
||
static constexpr const char* _type_key = "arith.IterSumExpr"; | ||
TVM_DECLARE_FINAL_OBJECT_INFO(IterSumExprNode, IterMapExprNode); | ||
}; | ||
|
||
/*! | ||
* \brief Managed reference to IterSumExprNode. | ||
* \sa IterSumExprNode | ||
*/ | ||
class IterSumExpr : public IterMapExpr { | ||
public: | ||
/*! | ||
* \brief constructor. | ||
* \param args The args to the sum. | ||
* \param base The base offset. | ||
*/ | ||
TVM_DLL IterSumExpr(Array<IterSplitExpr> args, PrimExpr base); | ||
|
||
TVM_DEFINE_OBJECT_REF_METHODS(IterSumExpr, IterMapExpr, IterSumExprNode); | ||
TVM_DEFINE_OBJECT_REF_COW_METHOD(IterSumExprNode); | ||
}; | ||
|
||
/*! | ||
* \brief Detect if indices can be written as | ||
* | ||
* [y_0 + c_0, y_1 + c_1, ..., y_n + c_n] | ||
* | ||
* Here y = some-quasi-affine-iter-map(input_iters) | ||
* and c are symbolic constants. | ||
* | ||
* We also requires that y_i and y_j to be independent for i != j. | ||
* | ||
* For returned value rv, the following is always true: | ||
* - rv[i]->args.size() <=1: only one iterator per element. | ||
* | ||
* \param indices The indices to detect pattern for. | ||
* \param input_iters Map from variable to iterator's range. | ||
* \param analyzer Analyzer used to get context information. | ||
* | ||
* \return The detected pattern if a match exists, | ||
* otherwise return an empty array. | ||
*/ | ||
Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters, | ||
arith::Analyzer* analyzer); | ||
|
||
} // namespace arith | ||
} // namespace tvm | ||
#endif // TVM_ARITH_ITER_AFFINE_MAP_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
""" Iterator (quasi)affine mapping patterns.""" | ||
import tvm._ffi | ||
from tvm.runtime import Object | ||
from tvm.ir import PrimExpr | ||
from . import _ffi_api | ||
|
||
|
||
class IterMapExpr(PrimExpr): | ||
"""Base class of all IterMap expressions.""" | ||
|
||
|
||
@tvm._ffi.register_object("arith.IterMark") | ||
class IterMark(Object): | ||
"""Mark the source as an iterator in [0, extent). | ||
Parameters | ||
---------- | ||
source : PrimExpr. | ||
The source expression. | ||
extent : PrimExpr | ||
The extent of the iterator. | ||
""" | ||
|
||
def __init__(self, source, extent): | ||
self.__init_handle_by_constructor__(_ffi_api.IterMark, source, extent) | ||
|
||
|
||
@tvm._ffi.register_object("arith.IterSplitExpr") | ||
class IterSplitExpr(IterMapExpr): | ||
"""Split of an iterator. | ||
result = floormod(floordiv(source, lower_factor), extent) * scale | ||
Parameters | ||
---------- | ||
source : IterMark | ||
The source marked iterator. | ||
lower_factor : PrimExpr | ||
The lower factor to split the domain. | ||
extent : PrimExpr | ||
The extent of the split. | ||
scale : PrimExpr | ||
Additional scale to the split. | ||
""" | ||
|
||
def __init__(self, source, lower_factor, extent, scale): | ||
self.__init_handle_by_constructor__( | ||
_ffi_api.IterSplitExpr, source, lower_factor, extent, scale | ||
) | ||
|
||
|
||
@tvm._ffi.register_object("arith.IterSumExpr") | ||
class IterSumExpr(IterMapExpr): | ||
"""Fuse multiple iterators by summing them with scaling. | ||
result = sum(args) + base | ||
Parameters | ||
---------- | ||
args : List[IterSplitExpr] | ||
The input to the sum expression. | ||
base : PrimExpr | ||
The base offset. | ||
""" | ||
|
||
def __init__(self, args, base): | ||
self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base) | ||
|
||
|
||
def detect_iter_map(indices, input_iters): | ||
"""Detect if indices can be written mapped iters from input_iters. | ||
Parameters | ||
---------- | ||
indices : List[PrimExpr] | ||
The input indices. | ||
input_iters : Map[Var, Range] | ||
The domain of each input iterators. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. May be good to add Returns in comment here. |
||
Returns | ||
------- | ||
results : List[IterSumExpr] | ||
The iter map matching result. | ||
Empty array if no match can be found. | ||
""" | ||
return _ffi_api.DetectIterMap(indices, input_iters) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Inconsistent with function arguments