Skip to content

Commit

Permalink
[TENSORIR] Add from_legacy_te_schedule attr to TE PrimFuncs (apache#…
Browse files Browse the repository at this point in the history
…8641)

* [TENSORIR] Add `from_legacy_te_schedule` attr to TE PrimFuncs

The `from_legacy_te_schedule` attr marks PrimFuncs created from TE
scheduling. Passes that only operate on TE scheduling check this attr
and are a no-op if it is not found. If `from_legacy_te_schedule` is false or
not set, then it is assumed that the PrimFunc is from TensorIR. Passes
specific to TensorIR now check for the absence of this attr.

* formatting

* enable passes regardless of te or not
  • Loading branch information
tkonolige authored and ylc committed Sep 29, 2021
1 parent e4cc296 commit 07535b8
Show file tree
Hide file tree
Showing 18 changed files with 204 additions and 54 deletions.
31 changes: 16 additions & 15 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ transform::Pass Filter(FCond fcond) {
return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
}

Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool for_te_schedule) {
Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
transform::PassContext pass_ctx = transform::PassContext::Current();

bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
Expand Down Expand Up @@ -214,17 +214,14 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool for
Array<tvm::transform::Pass> pass_list = user_lower_phase0;

// PHASE 1
if (for_te_schedule) {
pass_list.push_back(tir::transform::InjectPrefetch());
pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
} else {
pass_list.push_back(tir::transform::LowerInitBlock());
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::FlattenBuffer());
}
pass_list.push_back(tir::transform::InjectPrefetch());
pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
pass_list.push_back(tir::transform::LowerInitBlock());
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::BF16Legalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());
Expand Down Expand Up @@ -288,6 +285,10 @@ IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const
tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
f = WithAttr(std::move(f), "global_symbol", runtime::String(name));

// Mark this schedule as being converted from a TE schedule. Makes sure that
// the correct TE passes are run.
f = WithAttr(std::move(f), "from_legacy_te_schedule", Bool(true));

bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();

if (noalias) {
Expand All @@ -311,7 +312,7 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module")
});

IRModule LowerModule(IRModule mod, bool simple_mode) {
Array<transform::Pass> pass_list = CreatePassList(simple_mode, false);
Array<transform::Pass> pass_list = CreatePassList(simple_mode);
return LowerWithPassList(std::move(mod), pass_list);
}

Expand All @@ -331,7 +332,7 @@ IRModule LowerPrimFunc(tir::PrimFunc func, const std::string& name, bool simple_
IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));

// Get the pass list
Array<transform::Pass> pass_list = CreatePassList(simple_mode, false);
Array<transform::Pass> pass_list = CreatePassList(simple_mode);
return LowerWithPassList(std::move(mod), pass_list);
}

Expand All @@ -353,7 +354,7 @@ IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args, const std
const std::unordered_map<te::Tensor, tir::Buffer>& binds, bool simple_mode) {
IRModule mod = ScheduleToModule(std::move(sch), args, name, binds);
// Get the legacy TE pass list
Array<transform::Pass> pass_list = CreatePassList(simple_mode, true);
Array<transform::Pass> pass_list = CreatePassList(simple_mode);
return LowerWithPassList(mod, pass_list);
}

Expand Down
4 changes: 3 additions & 1 deletion src/te/schedule/schedule_postproc_to_primfunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list, Stmt body,
}

body = TensorToBufferMapper(std::move(extern_buffer))(std::move(body));
return tir::PrimFunc(params, body, VoidType(), buffer_map);
// We mark this PrimFunc as coming from a TE schedule
return WithAttr(tir::PrimFunc(params, body, VoidType(), buffer_map), "from_legacy_te_schedule",
Bool(true));
}

TVM_REGISTER_GLOBAL("schedule.SchedulePostProcToPrimFunc")
Expand Down
16 changes: 11 additions & 5 deletions src/tir/transforms/compact_buffer_region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "../../support/arena.h"
#include "../../support/utils.h"
#include "../schedule/utils.h"
#include "ir_utils.h"

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -452,11 +453,16 @@ class BufferCompactor : public StmtExprMutator {
};

PrimFunc CompactBufferAllocation(PrimFunc f) {
PrimFuncNode* fptr = f.CopyOnWrite();
std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> region =
BufferAccessRegionCollector::Collect(f);
fptr->body = BufferCompactor::Compact(f, region);
return f;
// Only apply this pass to TIR that is not from TE schedules
if (!IsFromLegacyTESchedule(f)) {
PrimFuncNode* fptr = f.CopyOnWrite();
std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> region =
BufferAccessRegionCollector::Collect(f);
fptr->body = BufferCompactor::Compact(f, region);
return f;
} else {
return f;
}
}

namespace transform {
Expand Down
13 changes: 10 additions & 3 deletions src/tir/transforms/convert_blocks_to_opaque.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "ir_utils.h"

namespace tvm {
namespace tir {

Expand Down Expand Up @@ -83,9 +85,14 @@ class OpaqueBlockConverter : public StmtExprMutator {
};

PrimFunc ConvertBlocksToOpaque(PrimFunc f) {
PrimFuncNode* fptr = f.CopyOnWrite();
fptr->body = OpaqueBlockConverter::Substitute(f);
return f;
// Only apply this pass to TIR that is not from TE schedules
if (!IsFromLegacyTESchedule(f)) {
PrimFuncNode* fptr = f.CopyOnWrite();
fptr->body = OpaqueBlockConverter::Substitute(f);
return f;
} else {
return f;
}
}

namespace transform {
Expand Down
12 changes: 9 additions & 3 deletions src/tir/transforms/flatten_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <tvm/tir/transform.h>

#include "../../support/utils.h"
#include "ir_utils.h"

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -151,9 +152,14 @@ class BufferFlattener : public StmtExprMutator {
};

PrimFunc FlattenBuffer(PrimFunc f) {
PrimFuncNode* fptr = f.CopyOnWrite();
fptr->body = BufferFlattener::Flatten(f);
return f;
// Only apply this pass to TIR that is not from TE schedules
if (!IsFromLegacyTESchedule(f)) {
PrimFuncNode* fptr = f.CopyOnWrite();
fptr->body = BufferFlattener::Flatten(f);
return f;
} else {
return f;
}
}

namespace transform {
Expand Down
13 changes: 10 additions & 3 deletions src/tir/transforms/inject_prefetch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@

#include <unordered_set>

#include "ir_utils.h"

namespace tvm {
namespace tir {

Expand Down Expand Up @@ -96,9 +98,14 @@ namespace transform {

Pass InjectPrefetch() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = PrefetchInjector()(std::move(n->body));
return f;
// Only apply this pass to TIR from TE schedules
if (IsFromLegacyTESchedule(f)) {
auto* n = f.CopyOnWrite();
n->body = PrefetchInjector()(std::move(n->body));
return f;
} else {
return f;
}
};
return CreatePrimFuncPass(pass_func, 0, "tir.InjectPrefetch", {});
}
Expand Down
5 changes: 5 additions & 0 deletions src/tir/transforms/ir_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -244,5 +244,10 @@ Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region
return result;
}

// Report whether `f` is tagged as originating from a legacy TE schedule.
// The `from_legacy_te_schedule` attr is looked up with Bool(false) supplied as
// the default, and the resulting Bool is handed back to the caller.
Bool IsFromLegacyTESchedule(PrimFunc f) {
  return f->GetAttr("from_legacy_te_schedule", Bool(false)).value();
}

} // namespace tir
} // namespace tvm
11 changes: 11 additions & 0 deletions src/tir/transforms/ir_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/runtime/device_api.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>

#include <limits>
Expand Down Expand Up @@ -213,6 +214,16 @@ Array<PrimExpr> ConvertIndices(const MatchBufferRegion& match_buffer,
*/
Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region);

/*!
 * \brief Check if a given PrimFunc originated from a TE schedule.
 *
 * Internally this reads the `from_legacy_te_schedule` attr of the PrimFunc.
 * The lookup is performed with a default of false, so only a PrimFunc whose
 * attr is set to true is reported as coming from a TE schedule.
 *
 * \param f PrimFunc to check
 * \return True if the PrimFunc was created from a TE schedule, false otherwise
 */
Bool IsFromLegacyTESchedule(PrimFunc f);

} // namespace tir
} // namespace tvm
#endif // TVM_TIR_TRANSFORMS_IR_UTILS_H_
13 changes: 10 additions & 3 deletions src/tir/transforms/lower_init_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "ir_utils.h"

namespace tvm {
namespace tir {

Expand Down Expand Up @@ -63,9 +65,14 @@ class InitBlockLower : public StmtMutator {
};

PrimFunc LowerInitBlock(PrimFunc func) {
auto fptr = func.CopyOnWrite();
fptr->body = InitBlockLower()(std::move(fptr->body));
return func;
// Only apply this pass to TIR that is not from TE schedules
if (!IsFromLegacyTESchedule(func)) {
auto fptr = func.CopyOnWrite();
fptr->body = InitBlockLower()(std::move(fptr->body));
return func;
} else {
return func;
}
}

namespace transform {
Expand Down
15 changes: 11 additions & 4 deletions src/tir/transforms/plan_update_buffer_allocation_location.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "ir_utils.h"

namespace tvm {
namespace tir {

Expand Down Expand Up @@ -145,10 +147,15 @@ class BufferAllocationLocator : public StmtExprMutator {
};

PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) {
auto fptr = func.CopyOnWrite();
BufferAllocationLocator locator(func);
fptr->body = locator(fptr->body);
return func;
// Only apply this pass to TIR that is not from TE schedules
if (!IsFromLegacyTESchedule(func)) {
auto fptr = func.CopyOnWrite();
BufferAllocationLocator locator(func);
fptr->body = locator(fptr->body);
return func;
} else {
return func;
}
}

namespace transform {
Expand Down
20 changes: 13 additions & 7 deletions src/tir/transforms/storage_flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -500,13 +500,19 @@ class StorageFlattener : public StmtExprMutator {
};

PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_attributes) {
auto fptr = func.CopyOnWrite();

IRVisitorWithAnalyzer bound_analyzer;
bound_analyzer(fptr->body);
fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size, create_bound_attributes,
&bound_analyzer)(std::move(fptr->body));
return func;
// Only apply this pass to TIR from TE schedules
Optional<Bool> from_legacy_te_schedule = func->GetAttr("from_legacy_te_schedule", Bool(false));
if (from_legacy_te_schedule.value()) {
auto fptr = func.CopyOnWrite();

IRVisitorWithAnalyzer bound_analyzer;
bound_analyzer(fptr->body);
fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size, create_bound_attributes,
&bound_analyzer)(std::move(fptr->body));
return func;
} else {
return func;
}
}

namespace transform {
Expand Down
27 changes: 23 additions & 4 deletions tests/python/unittest/test_lower_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,25 @@ def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:

# Reference module holding a lowered 128x128 matmul whose PrimFunc carries
# `from_legacy_te_schedule: True`.
# NOTE(review): presumably the expected output of lowering a TE schedule; the
# test that compares against it is not visible in this hunk -- confirm.
@tvm.script.tir
class LoweredModule:
def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
# function attr dict: symbol name, TE-origin marker, and the noalias flag
tir.func_attr(
{"global_symbol": "main", "from_legacy_te_schedule": True, "tir.noalias": True}
)
A = tir.match_buffer(a, [128, 128])
B = tir.match_buffer(b, [128, 128])
C = tir.match_buffer(c, [128, 128])
# body: buffers are accessed through flat offsets (row * 128 + col)
for x, y in tir.grid(128, 128):
# zero the output element before accumulating into it
C.data[x * 128 + y] = 0.0
for k in tir.serial(0, 128):
# C[x, y] += A[x, k] * B[y, k]
C.data[x * 128 + y] = tir.load("float32", C.data, x * 128 + y) + tir.load(
"float32", A.data, x * 128 + k
) * tir.load("float32", B.data, y * 128 + k)


@tvm.script.tir
class LoweredTIRModule:
def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
# function attr dict
tir.func_attr({"global_symbol": "main", "tir.noalias": True})
Expand Down Expand Up @@ -83,7 +102,7 @@ def test_lower_build_te_schedule():
def test_lower_build_tir_func():
# check lowering
ir_mod = tvm.lower(matmul)
tvm.ir.assert_structural_equal(ir_mod, LoweredModule())
tvm.ir.assert_structural_equal(ir_mod, LoweredTIRModule())
# check building
mod = tvm.build(matmul, target="llvm")
_check_module_with_numpy(mod)
Expand All @@ -95,16 +114,16 @@ def test_lower_build_tir_module():
ir_mod = IRModule({"main": func})
# check lowering
lowered_mod = tvm.lower(ir_mod)
tvm.ir.assert_structural_equal(lowered_mod, LoweredModule())
tvm.ir.assert_structural_equal(lowered_mod, LoweredTIRModule())
# check building
mod = tvm.build(ir_mod, target="llvm")
_check_module_with_numpy(mod)


def test_lower_build_lowered_module():
# check lowering
ir_mod = tvm.lower(LoweredModule())
tvm.ir.assert_structural_equal(ir_mod, LoweredModule())
ir_mod = tvm.lower(LoweredTIRModule())
tvm.ir.assert_structural_equal(ir_mod, LoweredTIRModule())
# check building
mod = tvm.build(ir_mod, target="llvm")
_check_module_with_numpy(mod)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import tir
from tvm import tir, te
from tvm.script import ty


Expand Down Expand Up @@ -371,6 +371,15 @@ def test_match_buffer():
_check(match_buffer_func, compacted_match_buffer_func)


def test_lower_te():
    """CompactBufferAllocation must leave a module built from a TE schedule unchanged."""
    source = te.placeholder((1,))
    shifted = te.compute((1,), lambda i: source[i] + 2)
    schedule = te.create_schedule(shifted.op)
    te_mod = tvm.driver.build_module.schedule_to_module(schedule, [source, shifted])
    compacted = tvm.tir.transform.CompactBufferAllocation()(te_mod)
    # CompactBufferAllocation should do nothing on TE
    tvm.ir.assert_structural_equal(compacted, te_mod)


if __name__ == "__main__":
test_elementwise()
test_unschedulable_block()
Expand Down
Loading

0 comments on commit 07535b8

Please sign in to comment.