Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfengsy committed Mar 1, 2021
1 parent 476d0c3 commit a2ec0a7
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 98 deletions.
14 changes: 5 additions & 9 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1064,9 +1064,8 @@ class BlockNode : public StmtNode {
// Need first reduce iter_vars, alloc_buffers and match_buffers to define new vars
return equal.DefEqual(iter_vars, other->iter_vars) &&
equal(alloc_buffers, other->alloc_buffers) &&
equal(match_buffers, other->match_buffers) &&
equal(reads, other->reads) && equal(writes, other->writes) &&
equal(body, other->body) && equal(init, other->init) &&
equal(match_buffers, other->match_buffers) && equal(reads, other->reads) &&
equal(writes, other->writes) && equal(body, other->body) && equal(init, other->init) &&
equal(annotations, other->annotations);
}

Expand All @@ -1091,11 +1090,8 @@ class BlockNode : public StmtNode {
*/
class Block : public Stmt {
public:
TVM_DLL explicit Block(Array<IterVar> iter_vars,
Array<BufferRegion> reads,
Array<BufferRegion> writes,
String name_hint,
Stmt body,
TVM_DLL explicit Block(Array<IterVar> iter_vars, Array<BufferRegion> reads,
Array<BufferRegion> writes, String name_hint, Stmt body,
Optional<Stmt> init = NullOpt,
Array<Buffer> alloc_buffers = Array<Buffer>(),
Array<MatchBufferRegion> match_buffers = Array<MatchBufferRegion>(),
Expand Down Expand Up @@ -1129,7 +1125,7 @@ class BlockRealizeNode : public StmtNode {

bool SEqualReduce(const BlockRealizeNode* other, SEqualReducer equal) const {
return equal(iter_values, other->iter_values) && equal(predicate, other->predicate) &&
equal(block, other->block);
equal(block, other->block);
}

void SHashReduce(SHashReducer hash_reduce) const {
Expand Down
160 changes: 75 additions & 85 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -621,22 +621,21 @@ TVM_REGISTER_GLOBAL("tir.BufferRegion").set_body_typed([](Buffer buffer, Array<R
TVM_REGISTER_NODE_TYPE(BufferRegionNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<BufferRegionNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const BufferRegionNode*>(node.get());
p->stream << op->buffer->name;
p->stream << "[";
for (size_t i = 0; i < op->region.size(); ++i) {
const auto& range = op->region[i];
p->Print(range->min);
if (!is_one(range->extent)) {
p->stream << ":";
p->Print(range->min + range->extent);
}
if (i != op->region.size() - 1) p->stream << ", ";
}
p->stream << "]";

});
.set_dispatch<BufferRegionNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const BufferRegionNode*>(node.get());
p->stream << op->buffer->name;
p->stream << "[";
for (size_t i = 0; i < op->region.size(); ++i) {
const auto& range = op->region[i];
p->Print(range->min);
if (!is_one(range->extent)) {
p->stream << ":";
p->Print(range->min + range->extent);
}
if (i != op->region.size() - 1) p->stream << ", ";
}
p->stream << "]";
});

// MatchBufferRegion
MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) {
Expand All @@ -653,24 +652,18 @@ TVM_REGISTER_GLOBAL("tir.MatchBufferRegion").set_body_typed([](Buffer buffer, Bu
TVM_REGISTER_NODE_TYPE(MatchBufferRegionNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<MatchBufferRegionNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const MatchBufferRegionNode*>(node.get());
p->PrintIndent();
p->stream << op->buffer->name << " = match_buffer_region(";
p->Print(op->source);
p->stream << ")\n";
});
.set_dispatch<MatchBufferRegionNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const MatchBufferRegionNode*>(node.get());
p->PrintIndent();
p->stream << op->buffer->name << " = match_buffer_region(";
p->Print(op->source);
p->stream << ")\n";
});

// Block
Block::Block(Array<IterVar> iter_vars,
Array<BufferRegion> reads,
Array<BufferRegion> writes,
String name_hint,
Stmt body,
Optional<Stmt> init,
Array<Buffer> alloc_buffers,
Array<MatchBufferRegion> match_buffers,
Map<String, ObjectRef> annotations,
Block::Block(Array<IterVar> iter_vars, Array<BufferRegion> reads, Array<BufferRegion> writes,
String name_hint, Stmt body, Optional<Stmt> init, Array<Buffer> alloc_buffers,
Array<MatchBufferRegion> match_buffers, Map<String, ObjectRef> annotations,
Span span) {
ObjectPtr<BlockNode> node = make_object<BlockNode>();
node->iter_vars = std::move(iter_vars);
Expand All @@ -689,11 +682,10 @@ Block::Block(Array<IterVar> iter_vars,
TVM_REGISTER_GLOBAL("tir.Block")
.set_body_typed([](Array<IterVar> iter_vars, Array<BufferRegion> reads,
Array<BufferRegion> writes, String name_hint, Stmt body, Optional<Stmt> init,
Array<Buffer> alloc_buffers,
Array<MatchBufferRegion> match_buffers, Map<String, ObjectRef> annotations,
Span span) {
return Block(iter_vars, reads, writes, name_hint, body, init, alloc_buffers,
match_buffers, annotations, span);
Array<Buffer> alloc_buffers, Array<MatchBufferRegion> match_buffers,
Map<String, ObjectRef> annotations, Span span) {
return Block(iter_vars, reads, writes, name_hint, body, init, alloc_buffers, match_buffers,
annotations, span);
});

TVM_REGISTER_NODE_TYPE(BlockNode);
Expand Down Expand Up @@ -750,33 +742,31 @@ void PrintBlockBody(const BlockNode* op, ReprPrinter* p) {
}
// Print body
p->Print(op->body);

}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<BlockNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const BlockNode*>(node.get());
p->PrintIndent();
PrintBlockTitle(op, p);
p->stream << "{\n";
p->indent += 2;
.set_dispatch<BlockNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const BlockNode*>(node.get());
p->PrintIndent();
PrintBlockTitle(op, p);
p->stream << "{\n";
p->indent += 2;

// Print block elements (e.g. reads/writes, etc)
PrintBlockSignature(op, p);
// Print block init and body
PrintBlockBody(op, p);
// Print block elements (e.g. reads/writes, etc)
PrintBlockSignature(op, p);
// Print block init and body
PrintBlockBody(op, p);

p->indent -= 2;
p->PrintIndent();
p->stream << "}\n";
});
p->indent -= 2;
p->PrintIndent();
p->stream << "}\n";
});

// BlockRealize
BlockRealize::BlockRealize(Array<PrimExpr> values, PrimExpr predicate, Block block, Span span) {
CHECK_EQ(block->iter_vars.size(), values.size())
<< "ValueError: BlockRealize needs to have the same number of iter_vars and binding values";
CHECK(predicate.dtype().is_bool())
<< "TypeError: Expect Block.predicate to be a bool expression";
CHECK(predicate.dtype().is_bool()) << "TypeError: Expect Block.predicate to be a bool expression";
ObjectPtr<BlockRealizeNode> node = make_object<BlockRealizeNode>();
node->iter_values = std::move(values);
node->predicate = std::move(predicate);
Expand All @@ -793,39 +783,39 @@ TVM_REGISTER_GLOBAL("tir.BlockRealize")
TVM_REGISTER_NODE_TYPE(BlockRealizeNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<BlockRealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const BlockRealizeNode*>(node.get());
auto* block_op = op->block.get();
p->PrintIndent();
PrintBlockTitle(block_op, p);
p->stream << "{\n";
p->indent += 2;
.set_dispatch<BlockRealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const BlockRealizeNode*>(node.get());
auto* block_op = op->block.get();
p->PrintIndent();
PrintBlockTitle(block_op, p);
p->stream << "{\n";
p->indent += 2;

// Print binding iter_values
for (size_t i = 0; i < block_op->iter_vars.size(); ++i) {
p->PrintIndent();
p->stream << "bind(";
p->Print(block_op->iter_vars[i]->var);
p->stream << ", ";
p->Print(op->iter_values[i]);
p->stream << ")\n";
}
// Print predicate
if (!is_one(op->predicate)) {
p->PrintIndent();
p->stream << "where(";
p->Print(op->predicate);
p->stream << ")\n";
}
// Print block elements (e.g. reads/writes, etc)
PrintBlockSignature(block_op, p);
// Print block init and body
PrintBlockBody(block_op, p);
// Print binding iter_values
for (size_t i = 0; i < block_op->iter_vars.size(); ++i) {
p->PrintIndent();
p->stream << "bind(";
p->Print(block_op->iter_vars[i]->var);
p->stream << ", ";
p->Print(op->iter_values[i]);
p->stream << ")\n";
}
// Print predicate
if (!is_one(op->predicate)) {
p->PrintIndent();
p->stream << "where(";
p->Print(op->predicate);
p->stream << ")\n";
}
// Print block elements (e.g. reads/writes, etc)
PrintBlockSignature(block_op, p);
// Print block init and body
PrintBlockBody(block_op, p);

p->indent -= 2;
p->PrintIndent();
p->stream << "}\n";
});
p->indent -= 2;
p->PrintIndent();
p->stream << "}\n";
});

PrimExpr TypeAnnotation(DataType dtype, Span span) {
static auto op = Op::Get("tir.type_annotation");
Expand Down
8 changes: 4 additions & 4 deletions tests/cpp/ir_functor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ TEST(IRF, StmtVisitor) {
MatchBufferRegion match_buffer_region(decl_buffer({1}), buffer_region);

// construct block and block_realize
Block block = Block({}, {buffer_region}, {buffer_region}, "block", body, body, {},
{match_buffer_region});
Block block =
Block({}, {buffer_region}, {buffer_region}, "block", body, body, {}, {match_buffer_region});
Stmt block_realize = BlockRealize({}, const_true(), block);

v.count = 0;
Expand Down Expand Up @@ -258,8 +258,8 @@ TEST(IRF, StmtMutator) {
BufferRegion buffer_region(buffer, {Range::FromMinExtent(x + 1, 1)});
MatchBufferRegion match_buffer_region(decl_buffer({1}), buffer_region);
// construct block and block_realize
Block block = Block({}, {buffer_region}, {buffer_region}, "block", body, body, {},
{match_buffer_region});
Block block =
Block({}, {buffer_region}, {buffer_region}, "block", body, body, {}, {match_buffer_region});
Stmt block_realize = BlockRealize({}, const_true(), block);
body = v(std::move(block_realize));
// the body should be changed
Expand Down

0 comments on commit a2ec0a7

Please sign in to comment.