diff --git a/cmake/modules/contrib/DNNL.cmake b/cmake/modules/contrib/DNNL.cmake index 9e36f39891e1..6642719cb485 100644 --- a/cmake/modules/contrib/DNNL.cmake +++ b/cmake/modules/contrib/DNNL.cmake @@ -19,11 +19,11 @@ if((USE_DNNL_CODEGEN STREQUAL "ON") OR (USE_DNNL_CODEGEN STREQUAL "JSON")) add_definitions(-DUSE_JSON_RUNTIME=1) tvm_file_glob(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/*.cc) list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC}) - list(APPEND COMPILER_SRCS ${JSON_RELAY_CONTRIB_SRC}) find_library(EXTERN_LIBRARY_DNNL dnnl) list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_DNNL}) - tvm_file_glob(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/dnnl_json_runtime.cc) + tvm_file_glob(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/dnnl_json_runtime.cc + src/runtime/contrib/dnnl/dnnl_utils.cc) list(APPEND RUNTIME_SRCS ${DNNL_CONTRIB_SRC}) message(STATUS "Build with DNNL JSON runtime: " ${EXTERN_LIBRARY_DNNL}) elseif(USE_DNNL_CODEGEN STREQUAL "C_SRC") diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 5b63016d2f9d..905c67f1c5b0 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -862,18 +862,18 @@ TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s Span span = Span()); // Intrinsic operators -#define TVM_DECLARE_INTRIN_UNARY(OpName) \ - inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \ - static const Op& op = Op::Get("tir." #OpName); \ - if (x.dtype().is_bfloat16()) { \ - DataType srcType = x.dtype(); \ - DataType dstType(kDLFloat, 32, srcType.lanes()); \ - PrimExpr castX = tir::Cast(dstType, {x}, span); \ - PrimExpr result = tir::Call(dstType, op, {castX}, span); \ - return tir::Cast(srcType, {result}, span); \ - } else { \ - return tir::Call(x.dtype(), op, {x}, span); \ - } \ +#define TVM_DECLARE_INTRIN_UNARY(OpName) \ + inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \ + static const Op& op = Op::Get("tir." #OpName); \ + if (x.dtype().is_bfloat16()) { \ + DataType bf16_dtype = x.dtype(); \ + DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes()); \ + PrimExpr x_fp32 = tir::Cast(fp32_dtype, {x}, span); \ + PrimExpr result_fp32 = tir::Call(fp32_dtype, op, {x_fp32}, span); \ + return tir::Cast(bf16_dtype, {result_fp32}, span); \ + } else { \ + return tir::Call(x.dtype(), op, {x}, span); \ + } \ } TVM_DECLARE_INTRIN_UNARY(exp); diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index 72e004b86853..2e975cf49c88 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -85,7 +85,6 @@ def _func_wrapper(expr): _register_external_op_helper("exp") _register_external_op_helper("log") _register_external_op_helper("sqrt") -_register_external_op_helper("round") _register_external_op_helper("nn.relu") _register_external_op_helper("nn.leaky_relu") _register_external_op_helper("tanh") @@ -212,7 +211,7 @@ def pattern_table(): def get_optimal_layout_for_conv( - data_layout, kernel_layout, weight_shape, out_shape, paddings, strides, dilates, groups + data_layout, kernel_layout, weight_shape, out_shape, paddings, strides, dilates, groups, dtype ): """Get the optimal layout of dnnl, given shape of conv2d. @@ -236,6 +235,7 @@ def get_optimal_layout_for_conv( strides, dilates, groups, + dtype, ) @@ -249,6 +249,7 @@ def get_optimal_layout_for_conv_transpose( strides, dilates, groups, + dtype, ): """Get the optimal layout of dnnl, given shape of tranposed conv2d. @@ -274,6 +275,7 @@ def get_optimal_layout_for_conv_transpose( strides, dilates, groups, + dtype, ) @@ -292,6 +294,21 @@ def get_shape(tensor): raise TypeError("Unsupport data type: %s" % type(tensor)) +def get_dtype(tensor): + """Get tensor's dtype.""" + if isinstance(tensor, relay.expr.Var): + return tensor.type_annotation.dtype + if isinstance(tensor, relay.expr.Constant): + return tensor.data.dtype + if isinstance(tensor, tvm.ir.tensor_type.TensorType): + return tensor.dtype + if isinstance(tensor, tvm.ir.container.Array): + return tensor[-1].dtype + if isinstance(tensor, relay.expr.Call): + return tensor.checked_type.dtype + raise TypeError("Unsupport data type: %s" % type(tensor)) + + def tag2layout(input_data, is_weight=False, conv_type="Conv1D"): """Transfer layout, denoted with `a, b, c, d, e`, into valid layout (NCHW / OIHW) of TVM.""" @@ -353,6 +370,7 @@ def alter_conv(attrs, inputs, tinfos, out_type): paddings = ",".join([str(x) for x in attrs.get_int_tuple("padding")]) strides = ",".join([str(x) for x in attrs.get_int_tuple("strides")]) dilates = ",".join([str(x) for x in attrs.get_int_tuple("dilation")]) + dtype = get_dtype(weight) new_attrs = dict(attrs) conv_type = type(attrs).__name__.split("Attrs")[0] @@ -365,6 +383,7 @@ def alter_conv(attrs, inputs, tinfos, out_type): strides, dilates, groups, + dtype, ) src_df, weight_df, dst_df = res.split(",") new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type) @@ -389,6 +408,7 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type): strides = ",".join([str(x) for x in attrs.get_int_tuple("strides")]) dilates = ",".join([str(x) for x in attrs.get_int_tuple("dilation")]) groups = str(attrs.groups) + dtype = get_dtype(weight) new_attrs = dict(attrs) conv_type = type(attrs).__name__.split("Attrs")[0] @@ -402,6 +422,7 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type): strides, dilates, groups, + dtype, ) src_df, weight_df, dst_df = res.split(",") new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type) diff --git a/src/relay/backend/contrib/dnnl/query_layout.cc b/src/relay/backend/contrib/dnnl/query_layout.cc index 7fb1d824c702..3762c1906f40 100755 --- a/src/relay/backend/contrib/dnnl/query_layout.cc +++ b/src/relay/backend/contrib/dnnl/query_layout.cc @@ -34,16 +34,17 @@ #include #include +#include "../../../../runtime/contrib/dnnl/dnnl_utils.h" #include "../../utils.h" #include "dnnl.hpp" - -using dim_t = dnnl_dim_t; -using dims_t = dnnl_dims_t; - namespace tvm { namespace relay { namespace contrib { +using dim_t = dnnl_dim_t; +using dims_t = dnnl_dims_t; +using tvm::runtime::contrib::dtype_dl2dnnl; + template inline void array_set(T* arr, const U& val, size_t size) { for (size_t i = 0; i < size; ++i) arr[i] = static_cast(val); @@ -192,7 +193,7 @@ void check_layout(bool var, bool ref) { std::string get_optimal_layout_for_conv(std::string data_layout, std::string kernel_layout, std::string weight_shape, std::string out_shape, std::string paddings, std::string strides, - std::string dilates, std::string G) { + std::string dilates, std::string G, std::string dtype) { check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true); check_layout(std::regex_match(kernel_layout, std::regex("(G?)OI(D?)(H?)W")), true); check_shapes({weight_shape, out_shape, paddings, strides, dilates, G}); @@ -200,7 +201,6 @@ std::string get_optimal_layout_for_conv(std::string data_layout, std::string ker dnnl::engine eng(dnnl::engine::kind::cpu, 0); dnnl::stream s(eng); using tag = dnnl::memory::format_tag; - using dt = dnnl::memory::data_type; dnnl::memory::dim groups = std::stoi(G); dnnl::memory::dims weight_dims_ = str2dims(weight_shape); @@ -249,9 +249,10 @@ std::string get_optimal_layout_for_conv(std::string data_layout, std::string ker dnnl::memory::dims conv_padding_l = padding_dims_l; dnnl::memory::dims conv_padding_r = padding_dims_r; - auto conv_src_md = dnnl::memory::desc({conv_src_dims}, dt::f32, tag::any); - auto conv_weights_md = dnnl::memory::desc({conv_weights_dims}, dt::f32, tag::any); - auto conv_dst_md = dnnl::memory::desc({conv_dst_dims}, dt::f32, tag::any); + auto dnnl_dtype = dtype_dl2dnnl(tvm::runtime::String2DLDataType(dtype)); + auto conv_src_md = dnnl::memory::desc({conv_src_dims}, dnnl_dtype, tag::any); + auto conv_weights_md = dnnl::memory::desc({conv_weights_dims}, dnnl_dtype, tag::any); + auto conv_dst_md = dnnl::memory::desc({conv_dst_dims}, dnnl_dtype, tag::any); auto conv_desc = dnnl::convolution_forward::desc( dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct, conv_src_md, @@ -276,7 +277,7 @@ std::string get_optimal_layout_for_conv_transpose(std::string data_layout, std::string weight_shape, std::string out_shape, std::string paddings, std::string output_paddings, std::string strides, std::string dilates, - std::string G) { + std::string G, std::string dtype) { check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true); check_layout(std::regex_match(kernel_layout, std::regex("(G?)((IO)|(OI))(D?)(H?)W")), true); check_shapes({weight_shape, out_shape, paddings, output_paddings, strides, dilates, G}); @@ -284,7 +285,6 @@ std::string get_optimal_layout_for_conv_transpose(std::string data_layout, dnnl::engine eng(dnnl::engine::kind::cpu, 0); dnnl::stream s(eng); using tag = dnnl::memory::format_tag; - using dt = dnnl::memory::data_type; dnnl::memory::dim groups = std::stoi(G); dnnl::memory::dims weight_dims_ = str2dims(weight_shape); @@ -338,9 +338,10 @@ std::string get_optimal_layout_for_conv_transpose(std::string data_layout, dnnl::memory::dims deconv_padding_l = padding_dims_l; dnnl::memory::dims deconv_padding_r = padding_dims_r; - auto deconv_src_md = dnnl::memory::desc({deconv_src_dims}, dt::f32, tag::any); - auto deconv_weights_md = dnnl::memory::desc({deconv_weights_dims}, dt::f32, tag::any); - auto deconv_dst_md = dnnl::memory::desc({deconv_dst_dims}, dt::f32, tag::any); + auto dnnl_dtype = dtype_dl2dnnl(tvm::runtime::String2DLDataType(dtype)); + auto deconv_src_md = dnnl::memory::desc({deconv_src_dims}, dnnl_dtype, tag::any); + auto deconv_weights_md = dnnl::memory::desc({deconv_weights_dims}, dnnl_dtype, tag::any); + auto deconv_dst_md = dnnl::memory::desc({deconv_dst_dims}, dnnl_dtype, tag::any); auto deconv_desc = dnnl::deconvolution_forward::desc( dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, deconv_src_md, @@ -364,13 +365,13 @@ std::string get_optimal_layout_for_conv_transpose(std::string data_layout, TVM_REGISTER_GLOBAL("relay.ir.get_optimal_layout_for_conv") .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = get_optimal_layout_for_conv(args[0], args[1], args[2], args[3], args[4], args[5], - args[6], args[7]); + args[6], args[7], args[8]); }); TVM_REGISTER_GLOBAL("relay.ir.get_optimal_layout_for_conv_transpose") .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = get_optimal_layout_for_conv_transpose(args[0], args[1], args[2], args[3], args[4], - args[5], args[6], args[7], args[8]); + args[5], args[6], args[7], args[8], args[9]); }); } // namespace contrib diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index dc2afecbaf91..f6a1c3b79080 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -33,6 +33,7 @@ #include "../json/json_node.h" #include "../json/json_runtime.h" #include "dnnl.hpp" +#include "dnnl_utils.h" namespace tvm { namespace runtime { @@ -66,8 +67,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase { // Fill in the input buffers. for (size_t i = 0; i < input_nodes_.size(); ++i) { auto eid = EntryID(input_nodes_[i], 0); - // TODO(@comaniac): Support other data lengths. - size_t offset_in_bytes = entry_out_mem_[eid].second * 4; + size_t offset_in_bytes = + entry_out_mem_[eid].second * ((data_entry_[eid]->dtype.bits + 7) / 8); size_t buffer_size = GetDataSize(*data_entry_[eid]); write_to_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, buffer_size, offset_in_bytes); @@ -82,7 +83,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase { // Read output buffers. for (size_t i = 0; i < outputs_.size(); ++i) { auto eid = EntryID(outputs_[i]); - size_t offset_in_bytes = entry_out_mem_[eid].second * 4; + size_t offset_in_bytes = + entry_out_mem_[eid].second * ((data_entry_[eid]->dtype.bits + 7) / 8); size_t buffer_size = GetDataSize(*data_entry_[eid]); read_from_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, buffer_size, offset_in_bytes); @@ -90,7 +92,501 @@ class DNNLJSONRuntime : public JSONRuntimeBase { } private: - // Build up the engine based on the input graph. + tag layout2tag(std::string layout) { + static const std::map str2tag = {{"nc", tag::nc}, + {"cn", tag::cn}, + {"tn", tag::tn}, + {"nt", tag::nt}, + {"ncw", tag::ncw}, + {"nwc", tag::nwc}, + {"nchw", tag::nchw}, + {"nhwc", tag::nhwc}, + {"chwn", tag::chwn}, + {"ncdhw", tag::ncdhw}, + {"ndhwc", tag::ndhwc}, + {"oi", tag::oi}, + {"io", tag::io}, + {"oiw", tag::oiw}, + {"owi", tag::owi}, + {"wio", tag::wio}, + {"iwo", tag::iwo}, + {"oihw", tag::oihw}, + {"hwio", tag::hwio}, + {"ohwi", tag::ohwi}, + {"ihwo", tag::ihwo}, + {"iohw", tag::iohw}, + {"oidhw", tag::oidhw}, + {"dhwio", tag::dhwio}, + {"odhwi", tag::odhwi}, + {"iodhw", tag::iodhw}, + {"idhwo", tag::idhwo}, + {"goiw", tag::goiw}, + {"gowi", tag::gowi}, + {"wigo", tag::wigo}, + {"gohwi", tag::gohwi}, + {"goihw", tag::goihw}, + {"hwigo", tag::hwigo}, + {"giohw", tag::giohw}, + {"goidhw", tag::goidhw}, + {"giodhw", tag::giodhw}, + {"godhwi", tag::godhwi}, + {"dhwigo", tag::dhwigo}, + {"tnc", tag::tnc}, + {"ntc", tag::ntc}, + {"ldnc", tag::ldnc}, + {"ldigo", tag::ldigo}, + {"ldgoi", tag::ldgoi}, + {"ldio", tag::ldio}, + {"ldoi", tag::ldoi}, + {"ldgo", tag::ldgo}, + {"nCdhw16c", tag::nCdhw16c}, + {"nCdhw4c", tag::nCdhw4c}, + {"nCdhw8c", tag::nCdhw8c}, + {"nChw16c", tag::nChw16c}, + {"nChw4c", tag::nChw4c}, + {"nChw8c", tag::nChw8c}, + {"nCw16c", tag::nCw16c}, + {"nCw4c", tag::nCw4c}, + {"nCw8c", tag::nCw8c}, + {"NCw16n16c", tag::NCw16n16c}, + {"NChw16n16c", tag::NChw16n16c}, + {"NCdhw16n16c", tag::NCdhw16n16c}, + {"NCdhw32n32c", tag::NCdhw32n32c}, + {"NChw32n32c", tag::NChw32n32c}, + {"IOhw16i16o", tag::IOhw16i16o}, + {"OI16i16o", tag::OI16i16o}, + {"OI16i32o", tag::OI16i32o}, + {"OI16i64o", tag::OI16i64o}, + {"OI8i16o2i", tag::OI8i16o2i}, + {"OI8i32o2i", tag::OI8i32o2i}, + {"OI8i64o2i", tag::OI8i64o2i}, + {"OI4i16o4i", tag::OI4i16o4i}, + {"OI4i32o4i", tag::OI4i32o4i}, + {"OI4i64o4i", tag::OI4i64o4i}, + {"Ohwi32o", tag::Ohwi32o}, + {"IOdhw16i16o", tag::IOdhw16i16o}, + {"gIOhw16i16o", tag::gIOhw16i16o}, + {"gOhwi32o", tag::gOhwi32o}, + {"Goidhw16g", tag::Goidhw16g}, + {"IOw16o16i", tag::IOw16o16i}, + {"OIw16i16o", tag::OIw16i16o}, + {"OIw16i32o", tag::OIw16i32o}, + {"OIw16i64o", tag::OIw16i64o}, + {"IOw16i16o", tag::IOw16i16o}, + {"gIOw16i16o", tag::gIOw16i16o}, + {"OIw16o16i", tag::OIw16o16i}, + {"Oiw16o", tag::Oiw16o}, + {"OIw4i16o4i", tag::OIw4i16o4i}, + {"OIw4i32o4i", tag::OIw4i32o4i}, + {"OIw4i64o4i", tag::OIw4i64o4i}, + {"OIw2i8o4i", tag::OIw2i8o4i}, + {"OIw4i4o", tag::OIw4i4o}, + {"OIw4o4i", tag::OIw4o4i}, + {"Oiw4o", tag::Oiw4o}, + {"OIw8i16o2i", tag::OIw8i16o2i}, + {"OIw8i32o2i", tag::OIw8i32o2i}, + {"OIw8i64o2i", tag::OIw8i64o2i}, + {"OIw8i8o", tag::OIw8i8o}, + {"OIw8o16i2o", tag::OIw8o16i2o}, + {"OIw8o8i", tag::OIw8o8i}, + {"OIw8o4i", tag::OIw8o4i}, + {"OIw16i16o4i", tag::OIw16i16o4i}, + {"OIw16i32o4i", tag::OIw16i32o4i}, + {"OIw16i48o4i", tag::OIw16i48o4i}, + {"OIw16i64o4i", tag::OIw16i64o4i}, + {"OIw16i16o2i", tag::OIw16i16o2i}, + {"OIw16i32o2i", tag::OIw16i32o2i}, + {"OIw16i48o2i", tag::OIw16i48o2i}, + {"OIw16i64o2i", tag::OIw16i64o2i}, + {"OIw16o16i2o", tag::OIw16o16i2o}, + {"Owi16o", tag::Owi16o}, + {"OwI16o2i", tag::OwI16o2i}, + {"Owi4o", tag::Owi4o}, + {"Owi8o", tag::Owi8o}, + {"IOhw16o16i", tag::IOhw16o16i}, + {"Ohwi16o", tag::Ohwi16o}, + {"OhwI16o2i", tag::OhwI16o2i}, + {"Ohwi4o", tag::Ohwi4o}, + {"Ohwi8o", tag::Ohwi8o}, + {"OIhw16i16o", tag::OIhw16i16o}, + {"OIhw16i32o", tag::OIhw16i32o}, + {"OIhw16i64o", tag::OIhw16i64o}, + {"OIhw16o16i", tag::OIhw16o16i}, + {"Oihw16o", tag::Oihw16o}, + {"OIhw4i16o4i", tag::OIhw4i16o4i}, + {"OIhw4i32o4i", tag::OIhw4i32o4i}, + {"OIhw4i64o4i", tag::OIhw4i64o4i}, + {"OIhw4i4o", tag::OIhw4i4o}, + {"OIhw4o4i", tag::OIhw4o4i}, + {"Oihw4o", tag::Oihw4o}, + {"OIhw8i16o2i", tag::OIhw8i16o2i}, + {"OIhw8i32o2i", tag::OIhw8i32o2i}, + {"OIhw8i64o2i", tag::OIhw8i64o2i}, + {"OIhw8i8o", tag::OIhw8i8o}, + {"OIhw8o16i2o", tag::OIhw8o16i2o}, + {"OIhw8o8i", tag::OIhw8o8i}, + {"OIhw8o4i", tag::OIhw8o4i}, + {"OIhw2i8o4i", tag::OIhw2i8o4i}, + {"IOdhw16o16i", tag::IOdhw16o16i}, + {"Odhwi16o", tag::Odhwi16o}, + {"OdhwI16o2i", tag::OdhwI16o2i}, + {"Odhwi4o", tag::Odhwi4o}, + {"Odhwi8o", tag::Odhwi8o}, + {"OIdhw16i16o", tag::OIdhw16i16o}, + {"OIdhw16i32o", tag::OIdhw16i32o}, + {"OIdhw16i64o", tag::OIdhw16i64o}, + {"OIdhw16o16i", tag::OIdhw16o16i}, + {"Oidhw16o", tag::Oidhw16o}, + {"OIdhw4i4o", tag::OIdhw4i4o}, + {"OIdhw4o4i", tag::OIdhw4o4i}, + {"Oidhw4o", tag::Oidhw4o}, + {"OIdhw8i16o2i", tag::OIdhw8i16o2i}, + {"OIdhw8i32o2i", tag::OIdhw8i32o2i}, + {"OIdhw8i64o2i", tag::OIdhw8i64o2i}, + {"OIdhw4i16o4i", tag::OIdhw4i16o4i}, + {"OIdhw16i16o4i", tag::OIdhw16i16o4i}, + {"OIdhw16i32o4i", tag::OIdhw16i32o4i}, + {"OIdhw16i48o4i", tag::OIdhw16i48o4i}, + {"OIdhw16i64o4i", tag::OIdhw16i64o4i}, + {"OIdhw16i16o2i", tag::OIdhw16i16o2i}, + {"OIdhw16i32o2i", tag::OIdhw16i32o2i}, + {"OIdhw16i48o2i", tag::OIdhw16i48o2i}, + {"OIdhw16i64o2i", tag::OIdhw16i64o2i}, + {"OIdhw4i32o4i", tag::OIdhw4i32o4i}, + {"OIdhw4i64o4i", tag::OIdhw4i64o4i}, + {"OIdhw2i8o4i", tag::OIdhw2i8o4i}, + {"OIdhw8i8o", tag::OIdhw8i8o}, + {"OIdhw8o8i", tag::OIdhw8o8i}, + {"OIdhw8o4i", tag::OIdhw8o4i}, + {"gIOw16o16i", tag::gIOw16o16i}, + {"gOIw16i16o", tag::gOIw16i16o}, + {"gOIw16o16i", tag::gOIw16o16i}, + {"gOiw16o", tag::gOiw16o}, + {"gOIw4i16o4i", tag::gOIw4i16o4i}, + {"gOIw2i8o4i", tag::gOIw2i8o4i}, + {"gOIw4i4o", tag::gOIw4i4o}, + {"gOIw4o4i", tag::gOIw4o4i}, + {"gOiw4o", tag::gOiw4o}, + {"gOIw8i16o2i", tag::gOIw8i16o2i}, + {"gOIw8i8o", tag::gOIw8i8o}, + {"gOIw8o16i2o", tag::gOIw8o16i2o}, + {"gOIw8o8i", tag::gOIw8o8i}, + {"gOIw8o4i", tag::gOIw8o4i}, + {"gOIw16i16o4i", tag::gOIw16i16o4i}, + {"gOIw16i16o2i", tag::gOIw16i16o2i}, + {"gOIw16o16i2o", tag::gOIw16o16i2o}, + {"gOwi16o", tag::gOwi16o}, + {"gOwI16o2i", tag::gOwI16o2i}, + {"gOwi4o", tag::gOwi4o}, + {"gOwi8o", tag::gOwi8o}, + {"Goiw8g", tag::Goiw8g}, + {"Goiw16g", tag::Goiw16g}, + {"gIOhw16o16i", tag::gIOhw16o16i}, + {"gOhwi16o", tag::gOhwi16o}, + {"gOhwI16o2i", tag::gOhwI16o2i}, + {"gOhwi4o", tag::gOhwi4o}, + {"gOhwi8o", tag::gOhwi8o}, + {"Goihw16g", tag::Goihw16g}, + {"gOIhw16i16o", tag::gOIhw16i16o}, + {"gOIhw16o16i", tag::gOIhw16o16i}, + {"gOihw16o", tag::gOihw16o}, + {"gOIhw4i16o4i", tag::gOIhw4i16o4i}, + {"gOIhw2i8o4i", tag::gOIhw2i8o4i}, + {"gOIhw4i4o", tag::gOIhw4i4o}, + {"gOIhw4o4i", tag::gOIhw4o4i}, + {"gOihw4o", tag::gOihw4o}, + {"Goihw8g", tag::Goihw8g}, + {"gOIhw8i16o2i", tag::gOIhw8i16o2i}, + {"gOIhw8i8o", tag::gOIhw8i8o}, + {"gOIhw8o16i2o", tag::gOIhw8o16i2o}, + {"OIw4o8i8o4i", tag::OIw4o8i8o4i}, + {"OIdhw4o8i8o4i", tag::OIdhw4o8i8o4i}, + {"OIhw4o8i8o4i", tag::OIhw4o8i8o4i}, + {"OIhw2o8i8o2i", tag::OIhw2o8i8o2i}, + {"gOIw4o8i8o4i", tag::gOIw4o8i8o4i}, + {"gOIdhw4o8i8o4i", tag::gOIdhw4o8i8o4i}, + {"gOIhw4o8i8o4i", tag::gOIhw4o8i8o4i}, + {"gOIhw2o8i8o2i", tag::gOIhw2o8i8o2i}, + {"OIhw16i16o4i", tag::OIhw16i16o4i}, + {"OIhw16i32o4i", tag::OIhw16i32o4i}, + {"OIhw16i48o4i", tag::OIhw16i48o4i}, + {"OIhw16i64o4i", tag::OIhw16i64o4i}, + {"OIhw16i16o2i", tag::OIhw16i16o2i}, + {"OIhw16i32o2i", tag::OIhw16i32o2i}, + {"OIhw16i48o2i", tag::OIhw16i48o2i}, + {"OIhw16i64o2i", tag::OIhw16i64o2i}, + {"OIhw16o16i2o", tag::OIhw16o16i2o}, + {"gOIhw16i16o4i", tag::gOIhw16i16o4i}, + {"gOIhw16i16o2i", tag::gOIhw16i16o2i}, + {"gOIhw16o16i2o", tag::gOIhw16o16i2o}, + {"gOIhw8o8i", tag::gOIhw8o8i}, + {"gOIhw8o4i", tag::gOIhw8o4i}, + {"gIOdhw16i16o", tag::gIOdhw16i16o}, + {"gIOdhw16o16i", tag::gIOdhw16o16i}, + {"gOdhwi16o", tag::gOdhwi16o}, + {"gOdhwI16o2i", tag::gOdhwI16o2i}, + {"gOdhwi4o", tag::gOdhwi4o}, + {"gOdhwi8o", tag::gOdhwi8o}, + {"gOIdhw16i16o", tag::gOIdhw16i16o}, + {"gOIdhw16o16i", tag::gOIdhw16o16i}, + {"gOidhw16o", tag::gOidhw16o}, + {"gOIdhw4i4o", tag::gOIdhw4i4o}, + {"gOIdhw4o4i", tag::gOIdhw4o4i}, + {"gOidhw4o", tag::gOidhw4o}, + {"gOIdhw8i16o2i", tag::gOIdhw8i16o2i}, + {"gOIdhw4i16o4i", tag::gOIdhw4i16o4i}, + {"gOIdhw16i16o4i", tag::gOIdhw16i16o4i}, + {"gOIdhw16i16o2i", tag::gOIdhw16i16o2i}, + {"gOIdhw2i8o4i", tag::gOIdhw2i8o4i}, + {"gOIdhw8i8o", tag::gOIdhw8i8o}, + {"gOIdhw8o8i", tag::gOIdhw8o8i}, + {"gOIdhw8o4i", tag::gOIdhw8o4i}, + {"gOIw2i4o2i", tag::gOIw2i4o2i}, + {"gOIhw2i4o2i", tag::gOIhw2i4o2i}, + {"gOIdhw2i4o2i", tag::gOIdhw2i4o2i}, + {"gOIw2o4i2o", tag::gOIw2o4i2o}, + {"gOIhw2o4i2o", tag::gOIhw2o4i2o}, + {"gOIdhw2o4i2o", tag::gOIdhw2o4i2o}, + {"gOIw4i8o2i", tag::gOIw4i8o2i}, + {"gOIhw4i8o2i", tag::gOIhw4i8o2i}, + {"gOIdhw4i8o2i", tag::gOIdhw4i8o2i}, + {"gOIw4o8i2o", tag::gOIw4o8i2o}, + {"gOIhw4o8i2o", tag::gOIhw4o8i2o}, + {"gOIdhw4o8i2o", tag::gOIdhw4o8i2o}, + {"ldOi32o", tag::ldOi32o}, + {"ldOI32o4i", tag::ldOI32o4i}, + {"ldgOi32o", tag::ldgOi32o}, + {"ldgOI32o2i", tag::ldgOI32o2i}, + {"ldgOI32o4i", tag::ldgOI32o4i}, + {"OwI16o4i", tag::OwI16o4i}, + {"OhwI16o4i", tag::OhwI16o4i}, + {"gOwI16o4i", tag::gOwI16o4i}, + {"gOhwI16o4i", tag::gOhwI16o4i}, + {"OdhwI16o4i", tag::OdhwI16o4i}, + {"gOdhwI16o4i", tag::gOdhwI16o4i}, + {"Owi32o", tag::Owi32o}, + {"OwI32o2i", tag::OwI32o2i}, + {"OwI32o4i", tag::OwI32o4i}, + {"Owi48o", tag::Owi48o}, + {"OwI48o2i", tag::OwI48o2i}, + {"OwI48o4i", tag::OwI48o4i}, + {"Owi64o", tag::Owi64o}, + {"OwI64o2i", tag::OwI64o2i}, + {"OwI64o4i", tag::OwI64o4i}, + {"wIo2i", tag::wIo2i}, + {"wIo4i", tag::wIo4i}, + {"gOwi32o", tag::gOwi32o}, + {"gOwI32o2i", tag::gOwI32o2i}, + {"gOwI32o4i", tag::gOwI32o4i}, + {"gOwi48o", tag::gOwi48o}, + {"gOwI48o2i", tag::gOwI48o2i}, + {"gOwI48o4i", tag::gOwI48o4i}, + {"gOwi64o", tag::gOwi64o}, + {"gOwI64o2i", tag::gOwI64o2i}, + {"gOwI64o4i", tag::gOwI64o4i}, + {"gwio", tag::gwio}, + {"gwIo2i", tag::gwIo2i}, + {"gwIo4i", tag::gwIo4i}, + {"OhwI32o", tag::OhwI32o}, + {"OhwI32o2i", tag::OhwI32o2i}, + {"OhwI32o4i", tag::OhwI32o4i}, + {"Ohwi48o", tag::Ohwi48o}, + {"OhwI48o2i", tag::OhwI48o2i}, + {"OhwI48o4i", tag::OhwI48o4i}, + {"Ohwi64o", tag::Ohwi64o}, + {"OhwI64o2i", tag::OhwI64o2i}, + {"OhwI64o4i", tag::OhwI64o4i}, + {"hwIo2i", tag::hwIo2i}, + {"hwIo4i", tag::hwIo4i}, + {"gOhwI32o", tag::gOhwI32o}, + {"gOhwI32o2i", tag::gOhwI32o2i}, + {"gOhwI32o4i", tag::gOhwI32o4i}, + {"gOhwi48o", tag::gOhwi48o}, + {"gOhwI48o2i", tag::gOhwI48o2i}, + {"gOhwI48o4i", tag::gOhwI48o4i}, + {"gOhwi64o", tag::gOhwi64o}, + {"gOhwI64o2i", tag::gOhwI64o2i}, + {"gOhwI64o4i", tag::gOhwI64o4i}, + {"ghwio", tag::ghwio}, + {"ghwIo2i", tag::ghwIo2i}, + {"ghwIo4i", tag::ghwIo4i}, + {"Odhwi32o", tag::Odhwi32o}, + {"OdhwI32o2i", tag::OdhwI32o2i}, + {"OdhwI32o4i", tag::OdhwI32o4i}, + {"Odhwi48o", tag::Odhwi48o}, + {"OdhwI48o2i", tag::OdhwI48o2i}, + {"OdhwI48o4i", tag::OdhwI48o4i}, + {"Odhwi64o", tag::Odhwi64o}, + {"OdhwI64o2i", tag::OdhwI64o2i}, + {"OdhwI64o4i", tag::OdhwI64o4i}, + {"dhwIo2i", tag::dhwIo2i}, + {"dhwIo4i", tag::dhwIo4i}, + {"gOdhwi32o", tag::gOdhwi32o}, + {"gOdhwI32o2i", tag::gOdhwI32o2i}, + {"gOdhwI32o4i", tag::gOdhwI32o4i}, + {"gOdhwi48o", tag::gOdhwi48o}, + {"gOdhwI48o2i", tag::gOdhwI48o2i}, + {"gOdhwI48o4i", tag::gOdhwI48o4i}, + {"gOdhwi64o", tag::gOdhwi64o}, + {"gOdhwI64o2i", tag::gOdhwI64o2i}, + {"gOdhwI64o4i", tag::gOdhwI64o4i}, + {"gdhwio", tag::gdhwio}, + {"gdhwIo2i", tag::gdhwIo2i}, + {"gdhwIo4i", tag::gdhwIo4i}, + {"ldIo32i", tag::ldIo32i}, + {"ldgIo32i", tag::ldgIo32i}, + {"ldgIO32i2o", tag::ldgIO32i2o}, + {"nCdhw32c", tag::nCdhw32c}, + {"nChw32c", tag::nChw32c}, + {"nCw32c", tag::nCw32c}, + {"NCw32n16c", tag::NCw32n16c}, + {"NChw32n16c", tag::NChw32n16c}, + {"NCdhw32n16c", tag::NCdhw32n16c}, + {"NCw32n32c", tag::NCw32n32c}, + {"OI16i16o4i", tag::OI16i16o4i}, + {"IOw8o16i2o", tag::IOw8o16i2o}, + {"IOhw8o16i2o", tag::IOhw8o16i2o}, + {"Owhi16o", tag::Owhi16o}, + {"OIdhw8o16i2o", tag::OIdhw8o16i2o}, + {"IOdhw8o16i2o", tag::IOdhw8o16i2o}, + {"Goiw4g", tag::Goiw4g}, + {"gIOw8o16i2o", tag::gIOw8o16i2o}, + {"Goiw32g", tag::Goiw32g}, + {"Goihw4g", tag::Goihw4g}, + {"gIOhw8o16i2o", tag::gIOhw8o16i2o}, + {"Goihw32g", tag::Goihw32g}, + {"gOwhi16o", tag::gOwhi16o}, + {"IOw4i8o8i4o", tag::IOw4i8o8i4o}, + {"IOhw4i8o8i4o", tag::IOhw4i8o8i4o}, + {"IOdhw4i8o8i4o", tag::IOdhw4i8o8i4o}, + {"gIOw4i8o8i4o", tag::gIOw4i8o8i4o}, + {"gIOhw4i8o8i4o", tag::gIOhw4i8o8i4o}, + {"gIOdhw4i8o8i4o", tag::gIOdhw4i8o8i4o}, + {"gOIdhw8o16i2o", tag::gOIdhw8o16i2o}, + {"gIOdhw8o16i2o", tag::gIOdhw8o16i2o}, + {"Goidhw32g", tag::Goidhw32g}, + {"OI16i32o4i", tag::OI16i32o4i}, + {"OI16i48o4i", tag::OI16i48o4i}, + {"OI16i64o4i", tag::OI16i64o4i}, + {"OI16i16o2i", tag::OI16i16o2i}, + {"OI16i32o2i", tag::OI16i32o2i}, + {"OI16i48o2i", tag::OI16i48o2i}, + {"OI16i64o2i", tag::OI16i64o2i}, + {"OwI16i16o2i", tag::OwI16i16o2i}, + {"gOwI16i16o2i", tag::gOwI16i16o2i}, + {"OhwI16i16o2i", tag::OhwI16i16o2i}, + {"gOhwI16i16o2i", tag::gOhwI16i16o2i}, + {"OdhwI16i16o2i", tag::OdhwI16i16o2i}, + {"gOdhwI16i16o2i", tag::gOdhwI16i16o2i}, + {"OwI16i16o4i", tag::OwI16i16o4i}, + {"gOwI16i16o4i", tag::gOwI16i16o4i}, + {"OhwI16i16o4i", tag::OhwI16i16o4i}, + {"gOhwI16i16o4i", tag::gOhwI16i16o4i}, + {"OdhwI16i16o4i", tag::OdhwI16i16o4i}, + {"gOdhwI16i16o4i", tag::gOdhwI16i16o4i}, + {"OwI16i32o2i", tag::OwI16i32o2i}, + {"OwI16i32o4i", tag::OwI16i32o4i}, + {"OwI16i48o2i", tag::OwI16i48o2i}, + {"OwI16i48o4i", tag::OwI16i48o4i}, + {"OwI16i64o2i", tag::OwI16i64o2i}, + {"OwI16i64o4i", tag::OwI16i64o4i}, + {"gOwI16i32o2i", tag::gOwI16i32o2i}, + {"gOwI16i32o4i", tag::gOwI16i32o4i}, + {"gOwI16i48o2i", tag::gOwI16i48o2i}, + {"gOwI16i48o4i", tag::gOwI16i48o4i}, + {"gOwI16i64o2i", tag::gOwI16i64o2i}, + {"gOwI16i64o4i", tag::gOwI16i64o4i}, + {"OhwI16i32o2i", tag::OhwI16i32o2i}, + {"OhwI16i32o4i", tag::OhwI16i32o4i}, + {"OhwI16i48o2i", tag::OhwI16i48o2i}, + {"OhwI16i48o4i", tag::OhwI16i48o4i}, + {"OhwI16i64o2i", tag::OhwI16i64o2i}, + {"OhwI16i64o4i", tag::OhwI16i64o4i}, + {"gOhwI16i32o2i", tag::gOhwI16i32o2i}, + {"gOhwI16i32o4i", tag::gOhwI16i32o4i}, + {"gOhwI16i48o2i", tag::gOhwI16i48o2i}, + {"gOhwI16i48o4i", tag::gOhwI16i48o4i}, + {"gOhwI16i64o2i", tag::gOhwI16i64o2i}, + {"gOhwI16i64o4i", tag::gOhwI16i64o4i}, + {"OdhwI16i32o2i", tag::OdhwI16i32o2i}, + {"OdhwI16i32o4i", tag::OdhwI16i32o4i}, + {"OdhwI16i48o2i", tag::OdhwI16i48o2i}, + {"OdhwI16i48o4i", tag::OdhwI16i48o4i}, + {"OdhwI16i64o2i", tag::OdhwI16i64o2i}, + {"OdhwI16i64o4i", tag::OdhwI16i64o4i}, + {"gOdhwI16i32o2i", tag::gOdhwI16i32o2i}, + {"gOdhwI16i32o4i", tag::gOdhwI16i32o4i}, + {"gOdhwI16i48o2i", tag::gOdhwI16i48o2i}, + {"gOdhwI16i48o4i", tag::gOdhwI16i48o4i}, + {"gOdhwI16i64o2i", tag::gOdhwI16i64o2i}, + {"gOdhwI16i64o4i", tag::gOdhwI16i64o4i}, + {"hwioG16g", tag::hwioG16g}, + {"NCdhw40n32c", tag::NCdhw40n32c}, + {"NChw40n32c", tag::NChw40n32c}, + {"NCw40n32c", tag::NCw40n32c}, + {"OIdhw4o8i8o2i", tag::OIdhw4o8i8o2i}, + {"OIhw4o8i8o2i", tag::OIhw4o8i8o2i}, + {"OIw4o8i8o2i", tag::OIw4o8i8o2i}, + {"gOIdhw4o8i8o2i", tag::gOIdhw4o8i8o2i}, + {"gOIhw4o8i8o2i", tag::gOIhw4o8i8o2i}, + {"gOIw4o8i8o2i", tag::gOIw4o8i8o2i}, + {"IOdhw4i8o8i2o", tag::IOdhw4i8o8i2o}, + {"IOhw4i8o8i2o", tag::IOhw4i8o8i2o}, + {"IOw4i8o8i2o", tag::IOw4i8o8i2o}, + {"gIOdhw4i8o8i2o", tag::gIOdhw4i8o8i2o}, + {"gIOhw4i8o8i2o", tag::gIOhw4i8o8i2o}, + {"gIOw4i8o8i2o", tag::gIOw4i8o8i2o}, + {"NCdhw40n16c", tag::NCdhw40n16c}, + {"NCw40n16c", tag::NCw40n16c}, + {"NChw40n16c", tag::NChw40n16c}, + {"NCw2c32n8c", tag::NCw2c32n8c}, + {"NChw2c32n8c", tag::NChw2c32n8c}, + {"NCdhw2c32n8c", tag::NCdhw2c32n8c}, + {"OIw2i8o16i4o", tag::OIw2i8o16i4o}, + {"OIhw2i8o16i4o", tag::OIhw2i8o16i4o}, + {"OIdhw2i8o16i4o", tag::OIdhw2i8o16i4o}, + {"OIw2o8i16o4i", tag::OIw2o8i16o4i}, + {"OIw2o8i16o2i", tag::OIw2o8i16o2i}, + {"IOw2i8o16i4o", tag::IOw2i8o16i4o}, + {"IOw2i8o16i2o", tag::IOw2i8o16i2o}, + {"OIhw2o8i16o4i", tag::OIhw2o8i16o4i}, + {"OIhw2o8i16o2i", tag::OIhw2o8i16o2i}, + {"IOhw2i8o16i4o", tag::IOhw2i8o16i4o}, + {"IOhw2i8o16i2o", tag::IOhw2i8o16i2o}, + {"OIdhw2o8i16o4i", tag::OIdhw2o8i16o4i}, + {"OIdhw2o8i16o2i", tag::OIdhw2o8i16o2i}, + {"IOdhw2i8o16i4o", tag::IOdhw2i8o16i4o}, + {"IOdhw2i8o16i2o", tag::IOdhw2i8o16i2o}, + {"gOIw2o8i16o2i", tag::gOIw2o8i16o2i}, + {"gIOw2i8o16i2o", tag::gIOw2i8o16i2o}, + {"gIOhw2i8o16i2o", tag::gIOhw2i8o16i2o}, + {"gIOdhw2i8o16i2o", tag::gIOdhw2i8o16i2o}, + {"gOIhw2o8i16o2i", tag::gOIhw2o8i16o2i}, + {"gOIdhw2o8i16o2i", tag::gOIdhw2o8i16o2i}, + {"gOIw2o8i16o4i", tag::gOIw2o8i16o4i}, + {"gOIhw2o8i16o4i", tag::gOIhw2o8i16o4i}}; + std::string key = ""; + for (const auto& c : layout) { + if (std::isalpha(c, std::locale("C"))) { + char lower_c = std::tolower(c); + if (std::isupper(c) && (layout.find(lower_c) != std::string::npos)) { + key.push_back(c); + } else { + key.push_back(lower_c); + } + } else if (std::isdigit(c)) { + key.push_back(c); + } else { + LOG(FATAL) << "invalid char '" << c << "' in " << layout << std::endl; + } + } + if (str2tag.count(key) == 0) { + LOG(WARNING) << "convert unregistered layout '" << key << "' to tag::any"; + return tag::any; + } else { + return str2tag.at(key); + } + } std::map elt_name2algo{ {"abs", dnnl::algorithm::eltwise_abs}, @@ -106,62 +602,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase { {"clip", dnnl::algorithm::eltwise_clip}, }; - std::map layout_dict{ - {"", tag::any}, - {"NCW", tag::ncw}, - {"NWC", tag::nwc}, - {"OIW", tag::oiw}, - {"GOIW", tag::goiw}, - {"NCHW", tag::nchw}, - {"NHWC", tag::nhwc}, - {"OIHW", tag::oihw}, - {"GOIHW", tag::goihw}, - {"NCDHW", tag::ncdhw}, - {"NDHWC", tag::ndhwc}, - {"OIDHW", tag::oidhw}, - {"GOIDHW", tag::goidhw}, - {"IOHW", tag::iohw}, - {"GIOHW", tag::giohw}, - {"IODHW", tag::iodhw}, - {"GIODHW", tag::giodhw}, - - // Blocking layout. - {"NCW8c", tag::nCw8c}, - {"NCW16c", tag::nCw16c}, - {"OIW16i16o", tag::OIw8i8o}, - {"OIW16i16o", tag::OIw16i16o}, - {"OWI8o", tag::Owi8o}, - {"OWI16o", tag::Owi16o}, - {"NCHW4c", tag::nChw4c}, - {"NCHW8c", tag::nChw8c}, - {"NCHW16c", tag::nChw16c}, - {"OIHW8i8o", tag::OIhw8i8o}, - {"IOHW8i8o", tag::any}, - {"OIHW16i16o", tag::OIhw16i16o}, - {"IOHW16i16o", tag::IOhw16i16o}, - {"GOIHW4i4o", tag::gOIhw4i4o}, - {"GOIHW8i8o", tag::gOIhw8i8o}, - {"GOIHW16i16o", tag::gOIhw16i16o}, - {"OHWI8o", tag::Ohwi8o}, - {"OHWI16o", tag::Ohwi16o}, - {"OHWI32o", tag::Ohwi32o}, - {"OHWI48o", tag::Ohwi48o}, - {"OHWI64o", tag::Ohwi64o}, - {"GOIHW8g", tag::Goihw8g}, - {"GOIHW16g", tag::Goihw16g}, - {"NCDHW8c", tag::nCdhw8c}, - {"NCDHW16c", tag::nCdhw16c}, - {"OIDHW16i16o", tag::OIdhw16i16o}, - {"IODHW16i16o", tag::IOdhw16i16o}, - {"OIDHW8i8o", tag::OIdhw8i8o}, - {"IODHW8i8o", tag::any}, - {"ODHWI8o", tag::Odhwi8o}, - {"ODHWI16o", tag::Odhwi16o}, - {"ODHWI32o", tag::Odhwi32o}, - {"ODHWI48o", tag::Odhwi48o}, - {"ODHWI64o", tag::Odhwi64o}, - }; - bool ParsingOpName(const std::string op_name, dnnl::primitive_attr attr) { // Define RegExp. std::regex bias_add_pat(".*_bias.*"); @@ -202,12 +642,13 @@ class DNNLJSONRuntime : public JSONRuntimeBase { } // Push the correct shapes of each axis into the output_dims for (auto a : axis) { - dnnl::memory::dim shape = 1; if (layout.find(a) != std::string::npos) { - shape *= input_dims[layout.find(a)]; + dnnl::memory::dim shape = input_dims[layout.find(a)]; char lower_a = std::tolower(a); - if (layout.find(lower_a) != std::string::npos) { - shape *= input_dims[layout.find(lower_a)]; + for (size_t i = 0; i < layout.size(); ++i) { + if (lower_a == layout[i]) { + shape *= input_dims[i]; + } } out_dims.push_back(shape); } @@ -238,6 +679,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { return out_dims; } + // Build up the engine based on the input graph. void BuildEngine() { engine_ = dnnl::engine(dnnl::engine::kind::cpu, 0); stream_ = dnnl::stream(engine_); @@ -301,11 +743,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase { // has not yet been bound to the other DNNL memory; otherwise it may have memory leak. ICHECK_EQ(entry_out_mem_.count(eid), 0); - // TODO(@comanic): Support other data types (i.e., int8). - auto data_node = nodes_[entry.id_]; - auto dltype = data_node.GetOpDataType()[entry.index_]; - ICHECK_EQ(dltype.bits, 32); - entry_out_mem_[eid] = {mem, offset}; return entry_out_mem_[eid].first; } @@ -338,17 +775,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase { std::string data_layout = node.GetAttr>("data_layout")[0]; std::string kernel_layout = node.GetAttr>("kernel_layout")[0]; - // Check layout. - if (layout_dict.find(data_layout) == layout_dict.end()) { - LOG(FATAL) << "Unsupported data layout for conv: " << data_layout; - } - - if (layout_dict.find(kernel_layout) == layout_dict.end()) { - layout_dict.insert({kernel_layout, tag::any}); - LOG(WARNING) << "Unregistered kernel layout for conv: " << kernel_layout - << ", transfer to tag::any"; - } - // Memory shapes. dnnl::memory::dims src_dims = TransDims2Plain(input_shape, data_layout); dnnl::memory::dims weights_dims_ = TransDims2Plain(weight_shape, kernel_layout); @@ -360,6 +786,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { dnnl::memory::dims dst_dims = src_dims; dst_dims[1] = channels; weights_dims_[0] = channels; + weights_dims_[1] = src_dims[1]; for (size_t i = 2; i < src_dims.size(); i++) { dnnl::memory::dim K = weights_dims_[i]; dnnl::memory::dim S = strides_dims[i - 2]; @@ -380,10 +807,11 @@ class DNNLJSONRuntime : public JSONRuntimeBase { } // Memory descriptions. - auto conv_src_md = dnnl::memory::desc(src_dims, dt::f32, layout_dict[data_layout]); - auto conv_weights_md = dnnl::memory::desc(weights_dims, dt::f32, layout_dict[kernel_layout]); - auto conv_bias_md = dnnl::memory::desc(bias_dims, dt::f32, tag::any); - auto conv_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::any); + auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]); + auto conv_src_md = dnnl::memory::desc(src_dims, dtype, layout2tag(data_layout)); + auto conv_weights_md = dnnl::memory::desc(weights_dims, dtype, layout2tag(kernel_layout)); + auto conv_bias_md = dnnl::memory::desc(bias_dims, dtype, tag::any); + auto conv_dst_md = dnnl::memory::desc(dst_dims, dtype, tag::any); // Conv description. auto conv_desc = @@ -413,7 +841,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { auto conv_dst_memory = BindDNNLMemory(out_entry, conv_prim_desc.dst_desc()); // Bias memory. - auto conv_bias_memory = dnnl::memory({bias_dims, dt::f32, tag::x}, engine_); + auto conv_bias_memory = dnnl::memory({bias_dims, dtype, tag::x}, engine_); if (has_bias) { auto bias_entry = node.GetInputs()[2]; BindDNNLMemory(bias_entry, conv_bias_memory); @@ -461,17 +889,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase { std::string data_layout = node.GetAttr>("data_layout")[0]; std::string kernel_layout = node.GetAttr>("kernel_layout")[0]; - // Check layout. - if (layout_dict.find(data_layout) == layout_dict.end()) { - LOG(FATAL) << "Unsupported data layout for deconv: " << data_layout; - } - - if (layout_dict.find(kernel_layout) == layout_dict.end()) { - layout_dict.insert({kernel_layout, tag::any}); - LOG(WARNING) << "Unregistered kernel layout for deconv: " << data_layout - << ", transfer to tag::any"; - } - // Memory shapes. dnnl::memory::dims src_dims = TransDims2Plain(input_shape, data_layout); dnnl::memory::dims weights_dims_ = TransDims2Plain(weight_shape, kernel_layout); @@ -482,6 +899,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase { kernel_layout.replace(kernel_layout.find("OI"), 2, "IO"); } } + weights_dims_[0] = channels; + weights_dims_[1] = src_dims[1]; dnnl::memory::dims bias_dims = {channels}; dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides); dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, true); @@ -508,10 +927,11 @@ class DNNLJSONRuntime : public JSONRuntimeBase { } // Memory descriptions. - auto deconv_src_md = dnnl::memory::desc(src_dims, dt::f32, layout_dict[data_layout]); - auto deconv_weights_md = dnnl::memory::desc(weights_dims, dt::f32, layout_dict[kernel_layout]); - auto deconv_bias_md = dnnl::memory::desc(bias_dims, dt::f32, tag::any); - auto deconv_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::any); + auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]); + auto deconv_src_md = dnnl::memory::desc(src_dims, dtype, layout2tag(data_layout)); + auto deconv_weights_md = dnnl::memory::desc(weights_dims, dtype, layout2tag(kernel_layout)); + auto deconv_bias_md = dnnl::memory::desc(bias_dims, dtype, tag::x); + auto deconv_dst_md = dnnl::memory::desc(dst_dims, dtype, tag::any); // Transposed covn2d description. auto deconv_desc = @@ -541,7 +961,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { auto deconv_dst_memory = BindDNNLMemory(out_entry, deconv_prim_desc.dst_desc()); // Bias memory. - auto deconv_bias_memory = dnnl::memory({bias_dims, dt::f32, tag::x}, engine_); + auto deconv_bias_memory = dnnl::memory({bias_dims, dtype, tag::x}, engine_); if (has_bias) { auto bias_entry = node.GetInputs()[2]; BindDNNLMemory(bias_entry, deconv_bias_memory); @@ -581,10 +1001,12 @@ class DNNLJSONRuntime : public JSONRuntimeBase { dnnl::memory::dims out_dims = out_shape; // Memory descriptions. - auto data_md = dnnl::memory::desc({data_dims, dt::f32, tag::nc}); - auto weight_md = dnnl::memory::desc({weight_dims, dt::f32, tag::nc}); - auto bias_md = dnnl::memory::desc({bias_dims, dt::f32, tag::x}); - auto dst_md = dnnl::memory::desc({out_dims, dt::f32, tag::nc}); + auto dl_dtype = nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]; + auto dtype = dtype_dl2dnnl(dl_dtype); + auto data_md = dnnl::memory::desc({data_dims, dtype, tag::nc}); + auto weight_md = dnnl::memory::desc({weight_dims, dtype, tag::nc}); + auto bias_md = dnnl::memory::desc({bias_dims, dtype, tag::x}); + auto dst_md = dnnl::memory::desc({out_dims, dtype, tag::nc}); // Dense description. auto dense_desc = dnnl::inner_product_forward::desc(dnnl::prop_kind::forward_inference, data_md, @@ -607,7 +1029,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { BindDNNLMemory(bias_entry, bias_memory); } else { float bias[OC] = {0}; - write_to_dnnl_memory(bias, bias_memory, OC * sizeof(float)); + write_to_dnnl_memory(bias, bias_memory, OC * ((dl_dtype.bits + 7) / 8)); } // Output memory. @@ -632,7 +1054,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase { float epsilon = std::stof(node.GetAttr>("epsilon")[0]); // Memory description. - dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dt::f32); + auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]); + dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dtype); // BN description. auto bn_desc = dnnl::batch_normalization_forward::desc( @@ -679,11 +1102,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase { std::vector str_dilates = node.GetAttr>("dilation"); std::string layout = node.GetAttr>("layout")[0]; - // Check layout. - if (layout_dict.find(layout) == layout_dict.end()) { - LOG(FATAL) << "Unsupported layout for pooling: " << layout; - } - // Attributes related to AvgPool if (algo == dnnl::algorithm::pooling_avg) { int int_countpad = std::stoi(node.GetAttr>("count_include_pad")[0]); @@ -701,8 +1119,9 @@ class DNNLJSONRuntime : public JSONRuntimeBase { dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r); // Memory descriptions. - auto pool_src_md = dnnl::memory::desc(src_dims, dt::f32, layout_dict[layout]); - auto pool_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::any); + auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]); + auto pool_src_md = dnnl::memory::desc(src_dims, dtype, layout2tag(layout)); + auto pool_dst_md = dnnl::memory::desc(dst_dims, dtype, tag::any); // Pooling description. auto pool_desc = dnnl::pooling_forward::desc(dnnl::prop_kind::forward_inference, algo, @@ -729,7 +1148,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase { auto data_entry = node.GetInputs()[0]; dnnl::memory::dims shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dt::f32); + auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]); + dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dtype); float alpha = 0., beta = 0.; if (op_name == "clip") { alpha = std::stof(node.GetAttr>("a_min")[0]); @@ -762,7 +1182,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase { if (axis < 0) { axis = shape.size() + axis; } - dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dt::f32); + auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]); + dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dtype); auto softmax_desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_inference, data_md, axis); @@ -790,7 +1211,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase { ICHECK_EQ(node.GetInputs().size(), 2U); for (auto entry : node.GetInputs()) { auto data_shape = nodes_[entry.id_].GetOpShape()[entry.index_]; - dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dt::f32); + auto dtype = dtype_dl2dnnl(nodes_[entry.id_].GetOpDataType()[entry.index_]); + dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dtype); data_dims.push_back(data_shape); data_mds.push_back(data_md); diff --git a/src/runtime/contrib/dnnl/dnnl_utils.cc b/src/runtime/contrib/dnnl/dnnl_utils.cc new file mode 100644 index 000000000000..7e79f1c939cf --- /dev/null +++ b/src/runtime/contrib/dnnl/dnnl_utils.cc @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/dnnl/dnnl_utils.cc + */ + +#include "dnnl_utils.h" + +namespace tvm { +namespace runtime { +namespace contrib { +using dt = dnnl::memory::data_type; +dt dtype_dl2dnnl(DLDataType dltype) { + dt dnnl_type = dt::undef; + if (dltype.code == DataType::TypeCode::kFloat) { + if (dltype.bits == 16) { + dnnl_type = dt::f16; + } else if (dltype.bits == 32) { + dnnl_type = dt::f32; + } + } else if (dltype.code == DataType::TypeCode::kBFloat && dltype.bits == 16) { + dnnl_type = dt::bf16; + } else if (dltype.code == DataType::TypeCode::kInt) { + if (dltype.bits == 8) { + dnnl_type = dt::s8; + } else if (dltype.bits == 32) { + dnnl_type = dt::s32; + } + } else if (dltype.code == DataType::TypeCode::kUInt && dltype.bits == 8) { + dnnl_type = dt::u8; + } + if (dnnl_type == dt::undef) { + LOG_ERROR << "unsupported datatype: code=" << dltype.code << ", bits=" << dltype.bits; + } + return dnnl_type; +} +} // namespace contrib +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/contrib/dnnl/dnnl_utils.h b/src/runtime/contrib/dnnl/dnnl_utils.h new file mode 100644 index 000000000000..4fb236f96f8b --- /dev/null +++ b/src/runtime/contrib/dnnl/dnnl_utils.h @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/dnnl/dnnl_utils.h + * \brief utils for DNNL. + */ + +#ifndef TVM_RUNTIME_CONTRIB_DNNL_DNNL_UTILS_H_ +#define TVM_RUNTIME_CONTRIB_DNNL_DNNL_UTILS_H_ + +#include + +#include "dnnl.hpp" + +namespace tvm { +namespace runtime { +namespace contrib { + +/*! + * \brief Convert a DLPack data type to a DNNL data type. + * \param dltype The DLPack data type. + * \return The corresponding DNNL data type. + */ +dnnl::memory::data_type dtype_dl2dnnl(DLDataType dltype); + +} // namespace contrib +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_CONTRIB_DNNL_DNNL_UTILS_H_ diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index 5e3ba83ce000..f784f7b49aac 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -205,10 +205,15 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) inline bool GetStoreRule(Array* index_rule, Array* shape_rule, const Layout& src_layout, const Layout& dst_layout) { - if (!src_layout.defined() || src_layout.name().empty() || !dst_layout.defined() || - dst_layout.name().empty()) { + if (!src_layout.defined() || src_layout.name().empty()) { + LOG(WARNING) << "src layout '" << src_layout.name() << "' is invalid."; return false; } + if (!dst_layout.defined() || dst_layout.name().empty()) { + LOG(WARNING) << "dst layout '" << dst_layout.name() << "' is invalid."; + return false; + } + for (size_t i = 0; i < dst_layout.ndim(); ++i) { const auto& store_axis = dst_layout[i]; const IterVar& store_axis_impl = dst_layout->axes[i]; @@ -237,7 +242,8 @@ inline bool GetStoreRule(Array* index_rule, Array* shape_rul } } if (tir::is_zero(index_store)) { - // Not convertible + LOG(WARNING) << "layout '" << src_layout.name() << "'-->'" << dst_layout.name() + << "' is not convertible."; return false; } diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py index 5baf6e06d347..fecd776d7065 100755 --- a/tests/python/contrib/test_dnnl.py +++ b/tests/python/contrib/test_dnnl.py @@ -37,6 +37,8 @@ ids=["compile", "run"], ) +bf16_supported = "avx512" in open("/proc/cpuinfo", "r").read() + def partition_for_dnnl(mod, params=None, alter_layout=True): """Partition the graph greedily offloading supported operators to DNNL. @@ -109,7 +111,10 @@ def partition_for_dnnl(mod, params=None, alter_layout=True): def vmobj_to_list(o): if isinstance(o, tvm.nd.NDArray): - return [o.numpy()] + o_np = o.numpy() + if o_np.dtype == np.uint16: + o_np = np.left_shift(o_np.astype("uint32"), 16).view("