Skip to content

Commit

Permalink
address comment
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfengsy committed Jun 30, 2021
1 parent ce4abcb commit b629c0b
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 118 deletions.
187 changes: 86 additions & 101 deletions src/tir/ir/specialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,118 +43,109 @@ inline bool IsParam(const PrimFunc& func, const Var& param) {
[&](const Var& var) { return var.same_as(param); });
}

/**************** Specializer ****************/

/*! \brief Mutator to specialize function and remove const parameters */
class PrimFuncSpecializer : public StmtExprMutator {
public:
explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
explicit PrimFuncSpecializer(const VarMap& var_map) : var_map_(var_map) {}

static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
PrimFuncSpecializer specializer(var_map);
// Updating Buffer map
Map<Var, Buffer> buffer_map;
bool buffer_map_updated = false;
for (const auto& it : f->buffer_map) {
const Var& var = it.first;
const Buffer& buffer = it.second;
Buffer new_buffer = specializer.MutateBuffer(buffer);
buffer_map.Set(var, new_buffer);
if (!new_buffer.same_as(buffer)) {
buffer_map_updated = true;
specializer.buffer_map_[buffer] = new_buffer;
}
}

// Updating parmeters
Array<Var> params;
bool param_updated = false;
for (const auto& var : f->params) {
// Remove parmeters which has been specialized.
if (var_map.find(var) == var_map.end()) {
params.push_back(var);
} else {
param_updated = true;
}
}

PrimFuncNode* f_ptr = f.CopyOnWrite();
f_ptr->params = std::move(params);
f_ptr->buffer_map = std::move(buffer_map);
f_ptr->body = specializer(std::move(f_ptr->body));

// Updating attrs
if (f->attrs.defined()) {
auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict;
for (const auto& kv : attr_dict) {
const String& key = kv.first;
const ObjectRef& value = kv.second;
if (value->IsInstance<PrimExprNode>()) {
attr_dict.Set(key, Substitute(Downcast<PrimExpr>(value), var_map));
}
}
// Updating function body
Stmt body = specializer(f->body);

if (param_updated || buffer_map_updated || !f->body.same_as(body)) {
PrimFuncNode* f_ptr = f.CopyOnWrite();
f_ptr->params = std::move(params);
f_ptr->buffer_map = std::move(buffer_map);
f_ptr->body = std::move(body);
}
return f;
}

private:
Stmt VisitStmt_(const BlockNode* op) final {
// Step.0. Define buffer mappings which is allocated inside the block
Array<Buffer> alloc_buffers = MutateArray(
op->alloc_buffers,
std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));

// Step.1. Recursively visit block body
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<BlockNode>();
ICHECK(op != nullptr);

Array<BufferRegion> reads = MutateArray(
op->reads,
std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
Array<BufferRegion> writes = MutateArray(
op->writes,
std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
Array<IterVar> block_vars = MutateArray(
op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1));
Optional<Stmt> init = NullOpt;
if (op->init.defined()) {
init = VisitStmt(op->init.value());
}
Stmt body = VisitStmt(op->body);

if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) &&
init.same_as(op->init)) {
if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads)) {
return GetRef<Block>(op);
} else {
ObjectPtr<BlockNode> n = CopyOnWrite(op);
n->alloc_buffers = std::move(alloc_buffers);
n->reads = std::move(reads);
n->writes = std::move(writes);
n->iter_vars = std::move(block_vars);
n->body = std::move(body);
n->init = std::move(init);
return Stmt(n);
}
}

Stmt VisitStmt_(const BufferStoreNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<BufferStoreNode>();
ICHECK(op != nullptr);
auto it = buffer_map_.find(op->buffer);
if (it == buffer_map_.end()) {
return GetRef<BufferStore>(op);
} else {
auto n = CopyOnWrite(op);
n->buffer = it->second;
return Stmt(n);
}

PrimExpr value = VisitExpr(op->value);
Array<PrimExpr> indices =
MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });

auto n = CopyOnWrite(op);
n->buffer = it->second;
n->value = std::move(value);
n->indices = std::move(indices);
return Stmt(n);
}

PrimExpr VisitExpr_(const BufferLoadNode* op) final {
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<BufferLoadNode>();
ICHECK(op != nullptr);
auto it = buffer_map_.find(op->buffer);
if (it == buffer_map_.end()) {
return GetRef<BufferLoad>(op);
} else {
auto n = make_object<BufferLoadNode>(*op);
n->buffer = it->second;
return PrimExpr(n);
}

Array<PrimExpr> indices =
MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });

auto n = CopyOnWrite(op);
n->buffer = it->second;
n->indices = std::move(indices);
return PrimExpr(n);
}

PrimExpr VisitExpr_(const VarNode* op) final {
Expand All @@ -167,21 +158,24 @@ class PrimFuncSpecializer : public StmtExprMutator {
}

private:
Buffer MutateBuffer(Buffer buffer) const {
BufferNode* buffer_ptr = buffer.CopyOnWrite();
Array<PrimExpr> new_shape, new_stride;
new_shape.reserve(buffer_ptr->shape.size());
new_shape.reserve(buffer_ptr->strides.size());
for (const auto& dim : buffer_ptr->shape) {
new_shape.push_back(Substitute(dim, var_map_));
}
for (const auto& stride : buffer_ptr->strides) {
new_shape.push_back(Substitute(stride, var_map_));
Buffer MutateBuffer(const Buffer& buffer) const {
Array<PrimExpr> shape =
MutateArray(buffer->shape, [this](const PrimExpr& e) { return Substitute(e, var_map_); });
Array<PrimExpr> strides =
MutateArray(buffer->strides, [this](const PrimExpr& e) { return Substitute(e, var_map_); });

PrimExpr elem_offset = Substitute(buffer->elem_offset, var_map_);

if (buffer->elem_offset.same_as(elem_offset) && buffer->shape.same_as(shape) &&
buffer->strides.same_as(strides)) {
return buffer;
} else {
auto n = make_object<BufferNode>(*buffer.get());
n->elem_offset = std::move(elem_offset);
n->shape = std::move(shape);
n->strides = std::move(strides);
return Buffer(n);
}
buffer_ptr->elem_offset = Substitute(buffer_ptr->elem_offset, var_map_);
buffer_ptr->shape = std::move(new_shape);
buffer_ptr->strides = std::move(new_stride);
return buffer;
}

Range MutateRange(const Range& range) {
Expand All @@ -190,21 +184,7 @@ class PrimFuncSpecializer : public StmtExprMutator {
if (min.same_as(range->min) && extent.same_as(range->extent)) {
return range;
} else {
ObjectPtr<RangeNode> n = CopyOnWrite(range.get());
n->min = std::move(min);
n->extent = std::move(extent);
return Range(n);
}
}

IterVar MutateIterVar(const IterVar& iter_var) {
Range range = MutateRange(iter_var->dom);
if (range.same_as(iter_var->dom)) {
return iter_var;
} else {
auto n = CopyOnWrite(iter_var.get());
n->dom = std::move(range);
return IterVar(n);
return Range::FromMinExtent(std::move(min), std::move(extent));
}
}

Expand All @@ -213,6 +193,7 @@ class PrimFuncSpecializer : public StmtExprMutator {
if (buf.same_as(alloc_buf)) {
return alloc_buf;
} else {
ICHECK(buffer_map_.find(alloc_buf) == buffer_map_.end());
buffer_map_[alloc_buf] = buf;
return buf;
}
Expand All @@ -226,26 +207,21 @@ class PrimFuncSpecializer : public StmtExprMutator {
if (it == buffer_map_.end() && region.same_as(buffer_region->region)) {
return buffer_region;
} else {
auto n = CopyOnWrite(buffer_region.get());
n->buffer = it->second;
n->region = std::move(region);
return BufferRegion(n);
return BufferRegion(it->second, std::move(region));
}
}

private:
/*! \brief The vars to be substitute and their values */
VarMap var_map_;
const VarMap& var_map_;
/*! \brief map from old buffer to mutated buffer */
std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
};

/**************** Implementation ****************/

PrimFunc Specialize(PrimFunc func, const Var& param, const Buffer& specific_buf) {
void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer& specific_buf,
VarMap* var_map) {
// preliminaries
tir::ExprDeepEqual equal;
VarMap var_map;

auto it = func->buffer_map.find(param);
CHECK(it != func->buffer_map.end())
Expand All @@ -259,23 +235,26 @@ PrimFunc Specialize(PrimFunc func, const Var& param, const Buffer& specific_buf)
<< "TypeError: The signature of target buffer exprected an independent Var, but got "
<< old_expr << ".";
const Var& var = Downcast<Var>(old_expr);
auto it = var_map.find(var);
if (it != var_map.end()) {
auto it = var_map->find(var);
if (it != var_map->end()) {
CHECK(equal(it->second, new_expr))
<< "ValueError: The assigned value of var " << var << " mismatched. " << it->second
<< " vs. " << new_expr << ".";
} else {
var_map[var] = new_expr;
(*var_map)[var] = new_expr;
}
}
};

// Check buffer dimensions
CHECK(specific_buf->shape.size() == buf_to_specialize->shape.size() &&
specific_buf->strides.size() == buf_to_specialize->strides.size())
CHECK(specific_buf->shape.size() == buf_to_specialize->shape.size())
<< "ValueError: The buffer dimensions mismatched" << buf_to_specialize->shape.size()
<< " vs. " << specific_buf->shape.size() << ".";

CHECK(specific_buf->strides.size() == buf_to_specialize->strides.size())
<< "ValueError: The buffer strides dimensions mismatched" << buf_to_specialize->strides.size()
<< " vs. " << specific_buf->strides.size() << ".";

// Updating var mapping using specific_expr
for (size_t i = 0; i < specific_buf->shape.size(); ++i) {
build_var_mapping(specific_buf->shape[i], buf_to_specialize->shape[i]);
Expand All @@ -284,42 +263,48 @@ PrimFunc Specialize(PrimFunc func, const Var& param, const Buffer& specific_buf)
build_var_mapping(specific_buf->strides[i], buf_to_specialize->strides[i]);
}
build_var_mapping(specific_buf->elem_offset, buf_to_specialize->elem_offset);
// Specialize function with var mapping
return PrimFuncSpecializer::Specialize(func, var_map);

// Check data_alignment and offset_factor.
// These two signatures are int, so we do not need map them.
CHECK_EQ(specific_buf->data_alignment, buf_to_specialize->data_alignment)
<< "ValueError: The buffer data_alignment mismatched" << buf_to_specialize->data_alignment
<< " vs. " << specific_buf->data_alignment << ".";

CHECK_EQ(specific_buf->offset_factor, buf_to_specialize->offset_factor)
<< "ValueError: The buffer offset_factor mismatched" << buf_to_specialize->offset_factor
<< " vs. " << specific_buf->offset_factor << ".";
}

PrimFunc Specialize(PrimFunc func, const Var& param, const PrimExpr& specific_expr) {
// preliminaries
VarMap var_map;
void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const PrimExpr& specific_expr,
VarMap* var_map) {
// check param is in PrimFunc's parameters
CHECK(IsParam(func, param)) << "ValueError: Specialize expects param to be in PrimFunc's params";
// specialize a param not in buffer_map
CHECK_EQ(func->buffer_map.count(param), 0)
<< "ValueError: Specialize expects param to not be in PrimFunc's buffer_map";
// build var mapping using specific_expr
var_map[param] = specific_expr;
// Specialize function with var mapping
return PrimFuncSpecializer::Specialize(std::move(func), var_map);
(*var_map)[param] = specific_expr;
}

/**************** FFI ****************/

TVM_REGISTER_GLOBAL("tir.Specialize")
.set_body_typed<PrimFunc(PrimFunc, Map<Var, ObjectRef>)>([](PrimFunc func,
Map<Var, ObjectRef> param_map) {
VarMap var_map;
for (const auto& kv : param_map) {
const Var& param = kv.first;
const ObjectRef& instance = kv.second;
if (instance->IsInstance<BufferNode>()) {
func = Specialize(std::move(func), param, Downcast<Buffer>(instance));
UpdateSpecializeVarMap(func, param, Downcast<Buffer>(instance), &var_map);
} else if (instance->IsInstance<PrimExprNode>()) {
func = Specialize(std::move(func), param, Downcast<PrimExpr>(instance));
UpdateSpecializeVarMap(func, param, Downcast<PrimExpr>(instance), &var_map);
} else {
LOG(FATAL) << "TypeError: specialize expected instance to be Buffer or PrimExpr, but got "
<< instance->GetTypeKey();
}
}
return func;
return PrimFuncSpecializer::Specialize(std::move(func), std::move(var_map));
});

} // namespace tir
Expand Down
Loading

0 comments on commit b629c0b

Please sign in to comment.