[TIR] Refactor BF16Legalize
This PR refactors BF16Legalize so that more computations run in f32.
We also split BF16Legalize into two steps:

- BF16ComputeLegalize changes all computation to f32 while keeping
  the external BF16 storage.
- BF16StorageLegalize changes all storage to u16.

BF16 kernels now accept tvm.nd.array arguments created with the bfloat16 dtype.
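As a quick illustration, here is a minimal sketch of that end-to-end flow. It assumes a TVM build containing this commit and the third-party ml_dtypes package for a NumPy-compatible bfloat16 dtype; both are assumptions, not part of this diff.

```python
# Sketch only: a bf16 kernel fed with bfloat16 tvm.nd.arrays.
# Assumes this commit's TVM and `ml_dtypes` (third-party) for the NumPy dtype.
import numpy as np
import tvm
from tvm import te
from ml_dtypes import bfloat16  # assumption: not part of TVM

n = 128
A = te.placeholder((n,), dtype="bfloat16", name="A")
B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, "bfloat16"), name="B")
f = tvm.build(te.create_schedule(B.op), [A, B], target="llvm")

a = tvm.nd.array(np.ones(n, dtype=bfloat16))
b = tvm.nd.array(np.zeros(n, dtype=bfloat16))
f(a, b)  # compute is legalized to f32; storage stays 16 bits wide
```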
tqchen committed Mar 27, 2023
1 parent 0d0d2f0 commit d2011c4
Showing 15 changed files with 615 additions and 423 deletions.
10 changes: 8 additions & 2 deletions include/tvm/tir/transform.h
@@ -337,11 +337,17 @@ TVM_DLL Pass CombineContextCall();
TVM_DLL Pass NarrowDataType(int target_bits);

/*!
- * \brief Legalize bf16 typed Ops. Add a cast to fp32
+ * \brief Legalize bf16 compute Ops. Add a cast to fp32
* before Ops, then add a cast back to bf16.
* \return The pass.
*/
-TVM_DLL Pass BF16Legalize();
+TVM_DLL Pass BF16ComputeLegalize();

+/*!
+ * \brief Legalize bf16 storage types to u16.
+ * \return The pass.
+ */
+TVM_DLL Pass BF16StorageLegalize();

/*!
* \brief Rewrite the pointer content type of arguments,
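To make the intended rewrite concrete, a small sketch of applying the compute pass from Python (TVMScript syntax varies across versions, so treat this as illustrative rather than exact):

```python
# Sketch: BF16ComputeLegalize rewrites bf16 arithmetic into f32, with casts
# at the bf16 loads/stores. TVMScript syntax here is illustrative.
import tvm
from tvm.script import tir as T

@T.prim_func
def double(a: T.Buffer((8,), "bfloat16"), b: T.Buffer((8,), "bfloat16")):
    for i in T.serial(8):
        b[i] = a[i] + a[i]

mod = tvm.IRModule({"main": double})
print(tvm.tir.transform.BF16ComputeLegalize()(mod))  # arithmetic now in f32
```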
6 changes: 1 addition & 5 deletions include/tvm/topi/elemwise.h
@@ -310,11 +310,7 @@ inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast",
inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "tensor",
std::string tag = kElementWise) {
return compute(
-      x->shape,
-      [&](const Array<Var>& i) {
-        return tvm::tir::Call(type, tvm::tir::builtin::reinterpret(), {x(i)});
-      },
-      name, tag);
+      x->shape, [&](const Array<Var>& i) { return reinterpret(type, x(i)); }, name, tag);
}

/*!
45 changes: 6 additions & 39 deletions python/tvm/tir/transform/transform.py
@@ -286,59 +286,26 @@ def RemoveStoreUndef():
return _ffi_api.RemoveStoreUndef() # type: ignore


-def BF16Legalize():
-    """Legalize bf16 typed Ops.
-    Runs BF16Promote, BF16CastElimination and BF16TypeLowering
+def BF16ComputeLegalize():
+    """Legalize bf16 compute Ops.
    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
-    return _ffi_api.BF16Legalize()  # type: ignore
+    return _ffi_api.BF16ComputeLegalize()  # type: ignore


-def BF16Promote():
-    """Promote bf16 to fp32. Add a cast to fp32
-    before Ops, then add a cast back to bf16.
+def BF16StorageLegalize():
+    """Legalize bf16 storage types to u16.
    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
-    return _ffi_api.BF16Promote()  # type: ignore


-def BF16CastElimination():
-    """Eliminate verbose casting between fp32 and bf16
-    Checks if the AST has the pattern:
-    castto32(castto16(some_fp32_op(...)))
-    The verbose casting is generated by BF16Promote for multiple
-    bf16 Ops in a row. e.g.:
-    X[i] + Y[i] + T[i] =>
-    bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i])))
-    After this pass:
-    bf16(float32(X[i]) + float32(Y[i]) + float32(T[i]))
-    Returns
-    -------
-    fpass : tvm.transform.Pass
-        The result pass
-    """
-    return _ffi_api.BF16CastElimination()  # type: ignore
-
-
-def BF16TypeLowering():
-    """Replace all bf16 type with uint16. Also lower the casting
-    between fp32 and bf16
-    Returns
-    -------
-    fpass : tvm.transform.Pass
-        The result pass
-    """
-    return _ffi_api.BF16TypeLowering()  # type: ignore
+    return _ffi_api.BF16StorageLegalize()  # type: ignore


def CommonSubexprElimTIR(enable_cse_tir: bool = True, identify_equiv_terms: bool = False):
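With the old three passes gone, the Python surface reduces to the two passes above. A minimal sketch of applying them by hand to an existing TIR IRModule mod (placeholder name; tvm.build normally schedules them for you, per the driver changes below):

```python
# Sketch: manual application of the two split passes to a module `mod`
# (`mod` is a placeholder IRModule, not defined in this diff).
import tvm

mod = tvm.tir.transform.BF16ComputeLegalize()(mod)  # arithmetic -> f32
mod = tvm.tir.transform.BF16StorageLegalize()(mod)  # bf16 storage -> u16
```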
3 changes: 2 additions & 1 deletion src/driver/driver_api.cc
@@ -218,7 +218,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
pass_list.push_back(tir::transform::FlattenBuffer());
-  pass_list.push_back(tir::transform::BF16Legalize());
+  pass_list.push_back(tir::transform::BF16ComputeLegalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());

@@ -605,6 +605,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
} else {
mixed_pass_list.push_back(tir::transform::MakePackedAPI());
}
+  mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());
mixed_pass_list.push_back(tir::transform::SplitHostDevice());

return transform::Sequential(mixed_pass_list);
@@ -138,7 +138,7 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode {
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
pass_list.push_back(tir::transform::FlattenBuffer());
-  pass_list.push_back(tir::transform::BF16Legalize());
+  pass_list.push_back(tir::transform::BF16ComputeLegalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());
pass_list.push_back(tir::transform::InjectVirtualThread());
2 changes: 1 addition & 1 deletion src/meta_schedule/postproc/verify_gpu_code.cc
@@ -169,7 +169,7 @@ class VerifyGPUCodeNode : public PostprocNode {
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
pass_list.push_back(tir::transform::FlattenBuffer());
-  pass_list.push_back(tir::transform::BF16Legalize());
+  pass_list.push_back(tir::transform::BF16ComputeLegalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());
// Phase 2
1 change: 0 additions & 1 deletion src/target/codegen.cc
@@ -46,7 +46,6 @@ runtime::Module Build(IRModule mod, Target target) {
.value()) {
mod = tir::transform::SkipAssert()(mod);
}

auto target_attr_map = tvm::TargetKind::GetAttrMap<FTVMTIRToRuntime>("TIRToRuntime");
if (target_attr_map.count(target->kind)) {
return target_attr_map[target->kind](mod, target);
4 changes: 4 additions & 0 deletions src/target/llvm/codegen_llvm.cc
@@ -828,6 +828,10 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va
llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* value) {
llvm::Type* target = DTypeToLLVMType(to);
if (value->getType() == target) return value;
+  // TODO(tvm-team): consider adding native support
+  ICHECK(!from.is_bfloat16()) << "BF16 needs to be storage-lowered first";
+  ICHECK(!to.is_bfloat16()) << "BF16 needs to be storage-lowered first";

if (to.is_handle()) {
return builder_->CreateBitCast(value, target);
} else if (to.is_uint() && to.bits() == 1) {
1 change: 0 additions & 1 deletion src/target/llvm/llvm_module.cc
@@ -325,7 +325,6 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) {
if (tm->getTargetTriple().isOSDarwin()) {
module_->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2);
}

std::string verify_errors_storage;
llvm::raw_string_ostream verify_errors(verify_errors_storage);
LOG_IF(FATAL, llvm::verifyModule(*module_, &verify_errors))
2 changes: 2 additions & 0 deletions src/tir/op/op.cc
@@ -324,6 +324,8 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span) {
// reinterpret
PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span) {
if (value.dtype() == t) return value;
+  ICHECK(value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes())
+      << "Bitcast requires size match " << t << " vs " << value.dtype();
return tir::Call(t, tir::builtin::reinterpret(), {value}, span);
}

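The new check makes reinterpret a strict same-width bitcast, which is what the bf16 <-> u16 storage lowering leans on. A short sketch from the Python side (the mismatched call is left commented out):

```python
# Sketch: tir.reinterpret now requires matching total bit widths.
import tvm

x = tvm.tir.const(1.0, "float32")
ok = tvm.tir.reinterpret("uint32", x)   # 32 -> 32 bits: accepted
# tvm.tir.reinterpret("uint16", x)      # 32 -> 16 bits: fails the new ICHECK
bf = tvm.tir.const(0.5, "bfloat16")
u = tvm.tir.reinterpret("uint16", bf)   # the bf16 <-> u16 storage case
```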
2 changes: 1 addition & 1 deletion src/tir/transforms/arg_binder.cc
@@ -184,7 +184,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes) ==
IntImm(DataType::UInt(16), buffer->dtype.lanes()));
if (!(buffer->dtype == DataType::Int(1) || buffer->dtype == DataType::Int(4) ||
-        buffer->dtype == DataType::UInt(4) || buffer->dtype == DataType::UInt(16))) {
+        buffer->dtype == DataType::UInt(4))) {
auto type_msg = tvm::tir::StringImm(type_err_msg.str());
asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
asserts_.emplace_back(AssertStmt(cond, type_msg, nop));