Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] tir.transform.StorageFlatten refactor #9091

Merged
merged 21 commits into from
Oct 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
e2afbbc
[TE] Improved flexibility of ArgBinder::BindDLTensor
Lunderberg Sep 22, 2021
806eb13
[TIR] Exposed ElemOffset as a member function of BufferNode.
Lunderberg Sep 22, 2021
7c68085
[TE] Pulled shape determination out of StorageFlattener
Lunderberg Sep 21, 2021
6599278
[TE] Refactor stride calculation out of StorageFlattener
Lunderberg Sep 22, 2021
0716e01
[TE] Refactor thread scope propagation out of StorageFlattener.
Lunderberg Sep 22, 2021
9fa935e
[TE] Refactor buffer bind mapping out of StorageFlattener.
Lunderberg Sep 22, 2021
1c729e6
[TIR] Removed checks on buffer->shape.size()
Lunderberg Sep 23, 2021
40ab5a4
[TIR] Relaxed check on a bufferview's striding.
Lunderberg Sep 23, 2021
f361263
[TIR] Fixed StorageFlatten test for shape_legalize.
Lunderberg Sep 23, 2021
5a79bc3
[TIR] Assigned storage scope
Lunderberg Sep 24, 2021
31c4269
Updated ICHECK_EQ to CHECK_EQ for a test that depends on user-provided
Lunderberg Sep 24, 2021
7d3f1ae
Added comments in storage_flatten.cc, indicating why buffer_bind_scope
Lunderberg Sep 24, 2021
506c391
Updated comment with a few examples of where compact buffers are
Lunderberg Sep 24, 2021
054cd72
Updated following @csullivan's comments.
Lunderberg Sep 27, 2021
acf1c3a
Added fuzzy mapping to the BufferShapeLegalize.
Lunderberg Sep 28, 2021
0b06e98
Updated BufferShapeLegalize, asserts need to be inside the buffer_bin…
Lunderberg Sep 28, 2021
87a5d48
Pulled all shape-dependent behavior into BufferShapeLegalize.
Lunderberg Sep 28, 2021
78af077
Added another pass to remove verifiable assert statements.
Lunderberg Sep 29, 2021
3d3ec42
Minor cleanup
Lunderberg Sep 29, 2021
ddfc56f
Updated to handle BufferRealizeNode with no defined bounds.
Lunderberg Sep 29, 2021
ee2dc22
Updated to be less aggressive when checking AssertStmt
Lunderberg Sep 29, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/tvm/tir/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,14 @@ class BufferNode : public Object {
return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32);
}

/*! \brief Determine the offset in the buffer of the given index.
*
* Returns the buffer offset, in number of elements of type dtype,
* without adjusting for number of lanes. (e.g. The number of
* float16x4 elements in a buffer of type float16x4.)
*/
PrimExpr ElemOffset(Array<PrimExpr> index) const;

static constexpr const char* _type_key = "tir.Buffer";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
Expand Down
24 changes: 12 additions & 12 deletions src/tir/ir/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,41 +246,41 @@ inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) {
// The buffer offset in convention of number of elements of
// original data ignoring number of lanes.
// We also perform optimization to simplify the indexing expression.
inline PrimExpr ElemOffset(const BufferNode* n, Array<PrimExpr> index) {
PrimExpr base = n->elem_offset;
PrimExpr BufferNode::ElemOffset(Array<PrimExpr> index) const {
PrimExpr base = this->elem_offset;
arith::Analyzer ana;
if (n->strides.size() == 0) {
if (this->strides.size() == 0) {
// Scalar case
if (n->shape.size() == 0 && index.size() == 1) {
if (this->shape.size() == 0 && index.size() == 1) {
auto is_int = index[0].as<IntImmNode>();
ICHECK(is_int && is_int->value == 0);
base = base + index[0];
} else {
ICHECK_EQ(n->shape.size(), index.size());
ICHECK_EQ(this->shape.size(), index.size());
if (index.size() > 0) {
PrimExpr offset = index[0];
for (size_t i = 1; i < index.size(); ++i) {
offset = MergeMulMod(&ana, offset * n->shape[i] + index[i]);
offset = MergeMulMod(&ana, offset * this->shape[i] + index[i]);
}
base = base + offset;
}
}
} else {
ICHECK_EQ(n->strides.size(), index.size());
ICHECK_EQ(this->strides.size(), index.size());
if (is_zero(base)) {
base = MergeMulMod(&ana, index[0] * n->strides[0]);
base = MergeMulMod(&ana, index[0] * this->strides[0]);
} else {
base = MergeMulMod(&ana, base + index[0] * n->strides[0]);
base = MergeMulMod(&ana, base + index[0] * this->strides[0]);
}
for (size_t i = 1; i < index.size(); ++i) {
base = MergeMulMod(&ana, base + index[i] * n->strides[i]);
base = MergeMulMod(&ana, base + index[i] * this->strides[i]);
}
}
return base;
}

inline PrimExpr BufferOffset(const BufferNode* n, Array<PrimExpr> index, DataType dtype) {
PrimExpr offset = ElemOffset(n, index);
PrimExpr offset = n->ElemOffset(index);
if (n->dtype.lanes() != 1) {
offset = offset * make_const(offset.dtype(), dtype.lanes());
}
Expand Down Expand Up @@ -353,7 +353,7 @@ Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const
ICHECK(n != nullptr);
arith::Analyzer ana;
begins = SimplifyArray(&ana, begins);
PrimExpr elem_offset = ana.Simplify(ElemOffset(n, begins));
PrimExpr elem_offset = ana.Simplify(n->ElemOffset(begins));
Array<PrimExpr> strides = n->strides;
if (strides.size() == 0) {
bool can_relax = true;
Expand Down
25 changes: 15 additions & 10 deletions src/tir/transforms/arg_binder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
def_handle_dtype_.Set(v_strides, tir::TypeAnnotation(tvm_shape_type));
init_nest_.emplace_back(
LetStmt(v_strides, TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop));
PrimExpr is_null = Call(DataType::Bool(1), builtin::isnullptr(), {v_strides});
PrimExpr v_strides_is_null = Call(DataType::Bool(1), builtin::isnullptr(), {v_strides});
if (buffer->strides.size() == 0) {
// Assert the buffer is compact
DataType stype = buffer->DefaultIndexType();
Expand All @@ -226,7 +226,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
foldl([](PrimExpr a, PrimExpr b, Span span) { return logical_and(a, b, span); },
const_true(1), conds),
stride_msg, Evaluate(0));
check = IfThenElse(Not(is_null), check, Stmt());
check = IfThenElse(Not(v_strides_is_null), check, Stmt());
asserts_.emplace_back(SeqStmt({check, Evaluate(0)}));
}
} else if (buffer->buffer_type == kAutoBroadcast) {
Expand All @@ -239,24 +239,29 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
PrimExpr value =
cast(buffer->shape[k].dtype(),
Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1)));
value = tvm::if_then_else(is_null, stride, value);
value = tvm::if_then_else(v_strides_is_null, stride, value);
value = tvm::if_then_else(buffer->shape[k] == 1, 0, value);
Bind_(buffer->strides[k], value, field_name.str(), true);
stride = analyzer_.Simplify(stride * buffer->shape[k]);
}
} else {
std::ostringstream stride_null_err_msg;
stride_null_err_msg << arg_name << ".strides: expected non-null strides.";
asserts_.emplace_back(
AssertStmt(Not(is_null), tvm::tir::StringImm(stride_null_err_msg.str()), nop));
PrimExpr stride_from_shape = 1;

for (size_t k = 0; k < buffer->strides.size(); ++k) {
for (int k = buffer->strides.size() - 1; k >= 0; k--) {
std::ostringstream field_name;
field_name << v_strides->name_hint << '[' << k << ']';

PrimExpr explicit_stride =
cast(buffer->shape[k].dtype(),
Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1)));

Bind_(buffer->strides[k],
cast(buffer->shape[k].dtype(),
Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1))),
tvm::if_then_else(v_strides_is_null, stride_from_shape, explicit_stride),
field_name.str(), true);

stride_from_shape *=
cast(buffer->shape[k].dtype(),
Load(tvm_shape_type, v_shape, IntImm(DataType::Int(32), k), const_true(1)));
}
}
// Byte_offset field.
Expand Down
Loading