Skip to content

Commit

Permalink
[TVMScript] T.allocate with T.decl_buffer syntax sugar for TVMScr…
Browse files Browse the repository at this point in the history
…ipt printer (#13813)

This PR implements the syntax sugar of `T.allocate` with `T.decl_buffer` for new TVMScript printer. This syntax sugar will skip the `T.allocate`, when its body is a matched `T.decl_buffer`, and the `Var` defined in `T.allocate` is only used in that `T.decl_buffer`. For example, it will change

```python
buffer_data = T.allocate([128, 128])
buffer = T.decl_buffer([128, 128], data=buffer_data)
```
into

```python
buffer = T.decl_buffer([128, 128])
```
but keep the following `T.allocate` unchanged:
```python
buffer_data = T.allocate([128, 128])
buffer_0 = T.decl_buffer([128, 128], data=buffer_data)
buffer_1 = T.decl_buffer([128, 128], data=buffer_data)
```
and
```python
buffer_data = T.allocate([128, 128])
buffer = T.decl_buffer([256, 256], data=buffer_data)
```
  • Loading branch information
cyx-6 authored Jan 20, 2023
1 parent 5e36ae3 commit 8f80e42
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 85 deletions.
43 changes: 1 addition & 42 deletions src/script/printer/tir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
* under the License.
*/
#include <tvm/runtime/device_api.h>
#include <tvm/tir/stmt_functor.h>

#include "./utils.h"

Expand Down Expand Up @@ -65,47 +64,7 @@ bool IsSimpleBuffer(const tir::Buffer& buf) {
}

int CountVarOccurrence(const tir::PrimFunc& f, const tir::Var& v) {
class OccurrenceCounter : public tir::StmtExprVisitor {
public:
int count = 0;
const tir::VarNode* v = nullptr;

void VisitExpr_(const tir::VarNode* op) final {
if (op == v) {
++count;
}
tir::StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const tir::BufferStoreNode* op) final {
VisitBuffer(op->buffer.get());
tir::StmtExprVisitor::VisitStmt_(op);
}

void VisitExpr_(const tir::BufferLoadNode* op) final {
VisitBuffer(op->buffer.get());
tir::StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const tir::DeclBufferNode* op) final {
VisitBuffer(op->buffer.get());
tir::StmtExprVisitor::VisitStmt_(op);
}

void VisitBuffer(const tir::BufferNode* buffer) {
VisitExpr(buffer->data);
for (const PrimExpr& shape_i : buffer->shape) {
VisitExpr(shape_i);
}
for (const PrimExpr& stride_i : buffer->strides) {
VisitExpr(stride_i);
}
VisitExpr(buffer->elem_offset);
}
};

OccurrenceCounter counter;
counter.v = v.get();
OccurrenceCounter counter(v.get());
counter(f->body);
for (const tir::Var& v : f->params) {
counter(v);
Expand Down
24 changes: 24 additions & 0 deletions src/script/printer/tir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,34 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
}));
});

bool IsAllocateDeclBufferPattern(const tir::AllocateNode* allocate) {
const tir::Var& buffer_var = allocate->buffer_var;
if (const tir::DeclBufferNode* decl_buffer = allocate->body.as<tir::DeclBufferNode>()) {
const tir::Buffer& buffer = decl_buffer->buffer;
if (buffer_var.same_as(buffer->data) && allocate->dtype == buffer->dtype &&
tir::is_one(allocate->condition) && !allocate->annotations.size() &&
allocate->extents.size() == buffer->shape.size()) {
tir::ExprDeepEqual expr_equal;
for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) {
if (!expr_equal(allocate->extents[i], buffer->shape[i])) {
return false;
}
}
return true;
}
}
return false;
}

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Allocate>( //
"", [](tir::Allocate stmt, ObjectPath p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d);
OccurrenceCounter counter(stmt->buffer_var.get());
counter(stmt->body);
if (counter.count == 1 && IsAllocateDeclBufferPattern(stmt.get())) {
return d->AsDoc(stmt->body, p->Attr("body"));
}
String storage_scope = tir::GetPtrStorageScope(stmt->buffer_var);
Array<ExprDoc> args;
Array<String> kwargs_keys;
Expand Down
45 changes: 45 additions & 0 deletions src/script/printer/tir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>

#include <string>
#include <unordered_map>
Expand Down Expand Up @@ -220,6 +221,50 @@ ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array<
ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame,
const IRDocsifier& d);

/*! \brief A Var occurrence counter visitor */
class OccurrenceCounter : public tir::StmtExprVisitor {
public:
/*! \brief The occurrence counter */
int count = 0;
/*! \brief The Var to count occurrence */
const tir::VarNode* v = nullptr;

void VisitExpr_(const tir::VarNode* op) final {
if (op == v) {
++count;
}
tir::StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const tir::BufferStoreNode* op) final {
VisitBuffer(op->buffer.get());
tir::StmtExprVisitor::VisitStmt_(op);
}

void VisitExpr_(const tir::BufferLoadNode* op) final {
VisitBuffer(op->buffer.get());
tir::StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const tir::DeclBufferNode* op) final {
VisitBuffer(op->buffer.get());
tir::StmtExprVisitor::VisitStmt_(op);
}

void VisitBuffer(const tir::BufferNode* buffer) {
VisitExpr(buffer->data);
for (const PrimExpr& shape_i : buffer->shape) {
VisitExpr(shape_i);
}
for (const PrimExpr& stride_i : buffer->strides) {
VisitExpr(stride_i);
}
VisitExpr(buffer->elem_offset);
}

explicit OccurrenceCounter(const tir::VarNode* var) { v = var; }
};

} // namespace printer
} // namespace script
} // namespace tvm
Expand Down
92 changes: 49 additions & 43 deletions tests/python/unittest/test_tvmscript_printer_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import tir as T
from tvm.script.printer import default
import tvm.testing


@contextmanager
Expand Down Expand Up @@ -327,6 +328,53 @@ def test_allocate():
)


def test_allocate_with_decl_buffer_sugar():
with IRBuilder() as ib:
with T.allocate([128, 128], "float32") as buffer_data:
with T.decl_buffer([128, 128], "float32", data=buffer_data) as buffer:
T.evaluate(0)
obj = ib.get()
_assert_print(
obj,
"""
with T.decl_buffer((128, 128)) as buffer:
T.evaluate(0)
""",
)


def test_allocate_with_decl_buffer_no_sugar_multi_usage():
with IRBuilder() as ib:
with T.allocate([128, 128], "float32") as buffer_data:
with T.decl_buffer([128, 128], "float32", data=buffer_data) as buffer:
T.evaluate(buffer_data)
obj = ib.get()
_assert_print(
obj,
"""
with T.allocate([128, 128], "float32", "global") as v:
buffer = T.decl_buffer((128, 128), data=v)
T.evaluate(v)
""",
)


def test_allocate_with_decl_buffer_no_sugar_mismatch():
with IRBuilder() as ib:
with T.allocate([128, 128], "float32") as buffer_data:
with T.decl_buffer([256, 256], "float32", data=buffer_data) as buffer:
T.evaluate(buffer_data)
obj = ib.get()
_assert_print(
obj,
"""
with T.allocate([128, 128], "float32", "global") as v:
buffer = T.decl_buffer((256, 256), data=v)
T.evaluate(v)
""",
)


def test_decl_buffer():
with IRBuilder() as ib:
with T.decl_buffer((10, 10), data=T.ptr("float32")):
Expand Down Expand Up @@ -686,46 +734,4 @@ def main():


if __name__ == "__main__":
test_prim_func()
test_prim_func_no_sugar_inlined_buffer()
test_prim_func_no_sugar_shared_buffer_data()
test_block_realize()
test_block()
test_buffer()
test_buffer_region()
test_buffer_load()
test_buffer_store()
test_match_buffer_region()
test_for()
test_let_stmt()
test_attr_stmt()
test_assert_stmt()
test_while()
test_allocate()
test_decl_buffer()
test_prefetch()
test_seq_stmt()
test_if_then_else()
test_evaluate()
test_buffer_realize()
test_var()
test_size_var()
test_iter_var()
test_string_imm()
test_cast()
test_binary_arith()
test_logical()
test_select()
test_ramp()
test_broadcast()
test_let_expr()
test_call()
test_comm_reducer()
test_any()
test_int_imm()
test_float_imm()
test_range()
test_prim_type()
test_pointer_type()
test_tuple_type()
test_remap()
tvm.testing.main()

0 comments on commit 8f80e42

Please sign in to comment.