[BYOC] Enable bfloat16 in DNNL BYOC (#11111)
* refine the code style (#10112)

* support more data types in oneDNN BYOC

* consider dtype when query layout

* support more translation of blocked layout

* refine log for invalid layout transform

* reset N and C for the weights

* support multi-blocking in TransDims2Plain()

* add tests for bf16 oneDNN BYOC

* unregister 'round' OP in oneDNN BYOC

* restore the criteria for fp32 tests

* disable test_prune_dnnl_subgraph for bf16

* fix typo in dnnl.py

* delete tag::format_tag_last

* delete 'is_weight' in layout2tag()

* reuse dtype_dl2dnnl()

* fix lint errors

* change to WARNING for invalid layout transform

* skip bf16 tests if AVX512 is unavailable
yangulei authored May 26, 2022
1 parent d519b03 commit 8135860
Showing 9 changed files with 758 additions and 174 deletions.
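For orientation, here is a minimal Python sketch of how the bfloat16 DNNL BYOC path introduced by this commit could be exercised. It is illustrative only: it assumes `partition_for_dnnl` and the `ToMixedPrecision` pass behave as in mainline TVM around this commit, that TVM was built with `USE_DNNL_CODEGEN` enabled, and that the CPU supports AVX512 (see the last bullet of the commit message).

```python
import tvm
from tvm import relay
from tvm.relay.op.contrib.dnnl import partition_for_dnnl

# A small fp32 conv2d workload.
data = relay.var("data", shape=(1, 32, 14, 14), dtype="float32")
weight = relay.var("weight", shape=(32, 32, 3, 3), dtype="float32")
out = relay.nn.conv2d(data, weight, padding=(1, 1), kernel_size=(3, 3), channels=32)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))

# Convert the graph to bfloat16, then hand it to the DNNL partitioner.
mod = relay.transform.InferType()(mod)
mod = relay.transform.ToMixedPrecision("bfloat16")(mod)
mod = partition_for_dnnl(mod, alter_layout=True)

# alter_layout=True triggers the layout queries changed below, now parameterized by dtype.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm -mcpu=cascadelake")
```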
4 changes: 2 additions & 2 deletions cmake/modules/contrib/DNNL.cmake
@@ -19,11 +19,11 @@ if((USE_DNNL_CODEGEN STREQUAL "ON") OR (USE_DNNL_CODEGEN STREQUAL "JSON"))
add_definitions(-DUSE_JSON_RUNTIME=1)
tvm_file_glob(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/*.cc)
list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC})
list(APPEND COMPILER_SRCS ${JSON_RELAY_CONTRIB_SRC})

find_library(EXTERN_LIBRARY_DNNL dnnl)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_DNNL})
tvm_file_glob(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/dnnl_json_runtime.cc)
tvm_file_glob(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/dnnl_json_runtime.cc
src/runtime/contrib/dnnl/dnnl_utils.cc)
list(APPEND RUNTIME_SRCS ${DNNL_CONTRIB_SRC})
message(STATUS "Build with DNNL JSON runtime: " ${EXTERN_LIBRARY_DNNL})
elseif(USE_DNNL_CODEGEN STREQUAL "C_SRC")
24 changes: 12 additions & 12 deletions include/tvm/tir/op.h
@@ -862,18 +862,18 @@ TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s
Span span = Span());

// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
static const Op& op = Op::Get("tir." #OpName); \
if (x.dtype().is_bfloat16()) { \
DataType srcType = x.dtype(); \
DataType dstType(kDLFloat, 32, srcType.lanes()); \
PrimExpr castX = tir::Cast(dstType, {x}, span); \
PrimExpr result = tir::Call(dstType, op, {castX}, span); \
return tir::Cast(srcType, {result}, span); \
} else { \
return tir::Call(x.dtype(), op, {x}, span); \
} \
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
static const Op& op = Op::Get("tir." #OpName); \
if (x.dtype().is_bfloat16()) { \
DataType bf16_dtype = x.dtype(); \
DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes()); \
PrimExpr x_fp32 = tir::Cast(fp32_dtype, {x}, span); \
PrimExpr result_fp32 = tir::Call(fp32_dtype, op, {x_fp32}, span); \
return tir::Cast(bf16_dtype, {result_fp32}, span); \
} else { \
return tir::Call(x.dtype(), op, {x}, span); \
} \
}

TVM_DECLARE_INTRIN_UNARY(exp);
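The net effect of the revised macro is that a unary intrinsic applied to a bfloat16 operand is evaluated in fp32 and the result is cast back to bf16. A rough Python illustration of the TIR expression the macro builds for `exp` (a hand-constructed sketch, not the code path TVM actually takes):

```python
import tvm
from tvm import tir

x = tir.Var("x", "bfloat16")                               # bf16 operand, as the macro sees it

x_fp32 = tir.Cast("float32", x)                            # upcast to fp32
y_fp32 = tir.call_intrin("float32", "tir.exp", x_fp32)     # evaluate the intrinsic in fp32
y = tir.Cast("bfloat16", y_fp32)                           # cast the result back to bf16

print(y)  # roughly: cast("bfloat16", tir.exp(cast("float32", x)))
```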
25 changes: 23 additions & 2 deletions python/tvm/relay/op/contrib/dnnl.py
@@ -85,7 +85,6 @@ def _func_wrapper(expr):
_register_external_op_helper("exp")
_register_external_op_helper("log")
_register_external_op_helper("sqrt")
_register_external_op_helper("round")
_register_external_op_helper("nn.relu")
_register_external_op_helper("nn.leaky_relu")
_register_external_op_helper("tanh")
@@ -212,7 +211,7 @@ def pattern_table():


def get_optimal_layout_for_conv(
data_layout, kernel_layout, weight_shape, out_shape, paddings, strides, dilates, groups
data_layout, kernel_layout, weight_shape, out_shape, paddings, strides, dilates, groups, dtype
):
"""Get the optimal layout of dnnl, given shape of conv2d.
@@ -236,6 +235,7 @@ def get_optimal_layout_for_conv(
strides,
dilates,
groups,
dtype,
)


@@ -249,6 +249,7 @@ def get_optimal_layout_for_conv_transpose(
strides,
dilates,
groups,
dtype,
):
"""Get the optimal layout of dnnl, given shape of tranposed conv2d.
@@ -274,6 +275,7 @@ def get_optimal_layout_for_conv_transpose(
strides,
dilates,
groups,
dtype,
)


@@ -292,6 +294,21 @@ def get_shape(tensor):
raise TypeError("Unsupport data type: %s" % type(tensor))


def get_dtype(tensor):
"""Get tensor's dtype."""
if isinstance(tensor, relay.expr.Var):
return tensor.type_annotation.dtype
if isinstance(tensor, relay.expr.Constant):
return tensor.data.dtype
if isinstance(tensor, tvm.ir.tensor_type.TensorType):
return tensor.dtype
if isinstance(tensor, tvm.ir.container.Array):
return tensor[-1].dtype
if isinstance(tensor, relay.expr.Call):
return tensor.checked_type.dtype
raise TypeError("Unsupport data type: %s" % type(tensor))
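A quick, hypothetical use of the new `get_dtype` helper defined above (assuming it is importable from the contrib module like the other helpers in `dnnl.py`):

```python
import numpy as np
from tvm import relay
from tvm.relay.op.contrib.dnnl import get_dtype

x = relay.var("x", shape=(1, 16, 8, 8), dtype="bfloat16")
w = relay.const(np.zeros((16, 16, 3, 3), dtype="float32"))

print(get_dtype(x))  # "bfloat16" -- taken from the Var's type_annotation
print(get_dtype(w))  # "float32"  -- taken from the Constant's underlying NDArray
```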


def tag2layout(input_data, is_weight=False, conv_type="Conv1D"):
"""Transfer layout, denoted with `a, b, c, d, e`,
into valid layout (NCHW / OIHW) of TVM."""
@@ -353,6 +370,7 @@ def alter_conv(attrs, inputs, tinfos, out_type):
paddings = ",".join([str(x) for x in attrs.get_int_tuple("padding")])
strides = ",".join([str(x) for x in attrs.get_int_tuple("strides")])
dilates = ",".join([str(x) for x in attrs.get_int_tuple("dilation")])
dtype = get_dtype(weight)
new_attrs = dict(attrs)
conv_type = type(attrs).__name__.split("Attrs")[0]

@@ -365,6 +383,7 @@ def alter_conv(attrs, inputs, tinfos, out_type):
strides,
dilates,
groups,
dtype,
)
src_df, weight_df, dst_df = res.split(",")
new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type)
@@ -389,6 +408,7 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type):
strides = ",".join([str(x) for x in attrs.get_int_tuple("strides")])
dilates = ",".join([str(x) for x in attrs.get_int_tuple("dilation")])
groups = str(attrs.groups)
dtype = get_dtype(weight)
new_attrs = dict(attrs)
conv_type = type(attrs).__name__.split("Attrs")[0]

@@ -402,6 +422,7 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type):
strides,
dilates,
groups,
dtype,
)
src_df, weight_df, dst_df = res.split(",")
new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type)
33 changes: 17 additions & 16 deletions src/relay/backend/contrib/dnnl/query_layout.cc
@@ -34,16 +34,17 @@
#include <regex>
#include <sstream>

#include "../../../../runtime/contrib/dnnl/dnnl_utils.h"
#include "../../utils.h"
#include "dnnl.hpp"

using dim_t = dnnl_dim_t;
using dims_t = dnnl_dims_t;

namespace tvm {
namespace relay {
namespace contrib {

using dim_t = dnnl_dim_t;
using dims_t = dnnl_dims_t;
using tvm::runtime::contrib::dtype_dl2dnnl;

template <typename T, typename U>
inline void array_set(T* arr, const U& val, size_t size) {
for (size_t i = 0; i < size; ++i) arr[i] = static_cast<T>(val);
@@ -192,15 +193,14 @@ void check_layout(bool var, bool ref) {
std::string get_optimal_layout_for_conv(std::string data_layout, std::string kernel_layout,
std::string weight_shape, std::string out_shape,
std::string paddings, std::string strides,
std::string dilates, std::string G) {
std::string dilates, std::string G, std::string dtype) {
check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true);
check_layout(std::regex_match(kernel_layout, std::regex("(G?)OI(D?)(H?)W")), true);
check_shapes({weight_shape, out_shape, paddings, strides, dilates, G});

dnnl::engine eng(dnnl::engine::kind::cpu, 0);
dnnl::stream s(eng);
using tag = dnnl::memory::format_tag;
using dt = dnnl::memory::data_type;

dnnl::memory::dim groups = std::stoi(G);
dnnl::memory::dims weight_dims_ = str2dims(weight_shape);
@@ -249,9 +249,10 @@ std::string get_optimal_layout_for_conv(std::string data_layout, std::string ker
dnnl::memory::dims conv_padding_l = padding_dims_l;
dnnl::memory::dims conv_padding_r = padding_dims_r;

auto conv_src_md = dnnl::memory::desc({conv_src_dims}, dt::f32, tag::any);
auto conv_weights_md = dnnl::memory::desc({conv_weights_dims}, dt::f32, tag::any);
auto conv_dst_md = dnnl::memory::desc({conv_dst_dims}, dt::f32, tag::any);
auto dnnl_dtype = dtype_dl2dnnl(tvm::runtime::String2DLDataType(dtype));
auto conv_src_md = dnnl::memory::desc({conv_src_dims}, dnnl_dtype, tag::any);
auto conv_weights_md = dnnl::memory::desc({conv_weights_dims}, dnnl_dtype, tag::any);
auto conv_dst_md = dnnl::memory::desc({conv_dst_dims}, dnnl_dtype, tag::any);

auto conv_desc = dnnl::convolution_forward::desc(
dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct, conv_src_md,
@@ -276,15 +277,14 @@ std::string get_optimal_layout_for_conv_transpose(std::string data_layout,
std::string weight_shape, std::string out_shape,
std::string paddings, std::string output_paddings,
std::string strides, std::string dilates,
std::string G) {
std::string G, std::string dtype) {
check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true);
check_layout(std::regex_match(kernel_layout, std::regex("(G?)((IO)|(OI))(D?)(H?)W")), true);
check_shapes({weight_shape, out_shape, paddings, output_paddings, strides, dilates, G});

dnnl::engine eng(dnnl::engine::kind::cpu, 0);
dnnl::stream s(eng);
using tag = dnnl::memory::format_tag;
using dt = dnnl::memory::data_type;

dnnl::memory::dim groups = std::stoi(G);
dnnl::memory::dims weight_dims_ = str2dims(weight_shape);
@@ -338,9 +338,10 @@ std::string get_optimal_layout_for_conv_transpose(std::string data_layout,
dnnl::memory::dims deconv_padding_l = padding_dims_l;
dnnl::memory::dims deconv_padding_r = padding_dims_r;

auto deconv_src_md = dnnl::memory::desc({deconv_src_dims}, dt::f32, tag::any);
auto deconv_weights_md = dnnl::memory::desc({deconv_weights_dims}, dt::f32, tag::any);
auto deconv_dst_md = dnnl::memory::desc({deconv_dst_dims}, dt::f32, tag::any);
auto dnnl_dtype = dtype_dl2dnnl(tvm::runtime::String2DLDataType(dtype));
auto deconv_src_md = dnnl::memory::desc({deconv_src_dims}, dnnl_dtype, tag::any);
auto deconv_weights_md = dnnl::memory::desc({deconv_weights_dims}, dnnl_dtype, tag::any);
auto deconv_dst_md = dnnl::memory::desc({deconv_dst_dims}, dnnl_dtype, tag::any);

auto deconv_desc = dnnl::deconvolution_forward::desc(
dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, deconv_src_md,
@@ -364,13 +365,13 @@ std::string get_optimal_layout_for_conv_transpose(std::string data_layout,
TVM_REGISTER_GLOBAL("relay.ir.get_optimal_layout_for_conv")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = get_optimal_layout_for_conv(args[0], args[1], args[2], args[3], args[4], args[5],
args[6], args[7]);
args[6], args[7], args[8]);
});

TVM_REGISTER_GLOBAL("relay.ir.get_optimal_layout_for_conv_transpose")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = get_optimal_layout_for_conv_transpose(args[0], args[1], args[2], args[3], args[4],
args[5], args[6], args[7], args[8]);
args[5], args[6], args[7], args[8], args[9]);
});

} // namespace contrib
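Because the packed functions now take one more positional argument, any direct caller must append the dtype string as well. A minimal sketch against the registration above (values are illustrative; a "bfloat16" query is only expected to succeed on AVX512-capable CPUs, per the commit message):

```python
import tvm

query = tvm.get_global_func("relay.ir.get_optimal_layout_for_conv")

# The trailing "bfloat16" is the new args[8]; with it, oneDNN can choose
# blocked layouts suited to bf16 kernels instead of always assuming fp32.
res = query("NCHW", "OIHW", "64,64,3,3", "1,64,56,56",
            "1,1,1,1", "1,1", "1,1", "1", "bfloat16")
print(res)
```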
