diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index f6a1c3b79080..a2417f012ea4 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -32,7 +32,12 @@ #include "../json/json_node.h" #include "../json/json_runtime.h" -#include "dnnl.hpp" + +// TODO(@apeskov): Have to mute warning from dnnl headers. +// -Wzero-as-null-pointer-constant and -Wdocumentation-unknown-command +#include + +#include "dnnl_tensor_requisite.h" #include "dnnl_utils.h" namespace tvm { @@ -43,552 +48,82 @@ using namespace tvm::runtime; using namespace tvm::runtime::json; class DNNLJSONRuntime : public JSONRuntimeBase { - using tag = dnnl::memory::format_tag; - using dt = dnnl::memory::data_type; - public: DNNLJSONRuntime(const std::string& symbol_name, const std::string& graph_json, const Array const_names) - : JSONRuntimeBase(symbol_name, graph_json, const_names) {} + : JSONRuntimeBase(symbol_name, graph_json, const_names), + next_unique_eid_offset_(data_entry_.size()), + run_arg_eid_(input_var_eid_) { + for (const auto e : outputs_) run_arg_eid_.push_back(EntryID(e)); + } - const char* type_key() const { return "dnnl_json"; } + const char* type_key() const override { return "dnnl_json"; } void Init(const Array& consts) override { - BuildEngine(); - ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required."; // Setup constants entries for weights. SetupConstants(consts); + BuildEngine(); } - void Run() override { - // Fill in the input buffers. - for (size_t i = 0; i < input_nodes_.size(); ++i) { - auto eid = EntryID(input_nodes_[i], 0); - size_t offset_in_bytes = - entry_out_mem_[eid].second * ((data_entry_[eid]->dtype.bits + 7) / 8); - size_t buffer_size = GetDataSize(*data_entry_[eid]); - write_to_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, buffer_size, - offset_in_bytes); - } + /* Unused stub implementation */ + void Run() override { LOG(FATAL) << "Unreachable code"; } - // Invoke the engine through intepreting the stream. - for (size_t i = 0; i < net_.size(); ++i) { - net_.at(i).execute(stream_, net_args_.at(i)); - } - stream_.wait(); - - // Read output buffers. - for (size_t i = 0; i < outputs_.size(); ++i) { - auto eid = EntryID(outputs_[i]); - size_t offset_in_bytes = - entry_out_mem_[eid].second * ((data_entry_[eid]->dtype.bits + 7) / 8); - size_t buffer_size = GetDataSize(*data_entry_[eid]); - read_from_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, buffer_size, - offset_in_bytes); + /* Thread safe implementation of Run. Keep runtime instance immutable */ + void Run(const TVMArgs& args) const { + auto arg_data_provider = makeIODataProvider(args); + auto mem_solver = tensor_registry_.MakeSolver(arg_data_provider); + // Execute primitives one by one + for (const auto& act : net_) { + auto prim = std::get<0>(act); + auto arg_reqs = std::get<1>(act); + + // Find proper dnnl::memory buffers + std::unordered_map mem_args; + for (const auto& kvp : arg_reqs) mem_args[kvp.first] = mem_solver(kvp.second); + + prim.execute(stream_, mem_args); } } - private: - tag layout2tag(std::string layout) { - static const std::map str2tag = {{"nc", tag::nc}, - {"cn", tag::cn}, - {"tn", tag::tn}, - {"nt", tag::nt}, - {"ncw", tag::ncw}, - {"nwc", tag::nwc}, - {"nchw", tag::nchw}, - {"nhwc", tag::nhwc}, - {"chwn", tag::chwn}, - {"ncdhw", tag::ncdhw}, - {"ndhwc", tag::ndhwc}, - {"oi", tag::oi}, - {"io", tag::io}, - {"oiw", tag::oiw}, - {"owi", tag::owi}, - {"wio", tag::wio}, - {"iwo", tag::iwo}, - {"oihw", tag::oihw}, - {"hwio", tag::hwio}, - {"ohwi", tag::ohwi}, - {"ihwo", tag::ihwo}, - {"iohw", tag::iohw}, - {"oidhw", tag::oidhw}, - {"dhwio", tag::dhwio}, - {"odhwi", tag::odhwi}, - {"iodhw", tag::iodhw}, - {"idhwo", tag::idhwo}, - {"goiw", tag::goiw}, - {"gowi", tag::gowi}, - {"wigo", tag::wigo}, - {"gohwi", tag::gohwi}, - {"goihw", tag::goihw}, - {"hwigo", tag::hwigo}, - {"giohw", tag::giohw}, - {"goidhw", tag::goidhw}, - {"giodhw", tag::giodhw}, - {"godhwi", tag::godhwi}, - {"dhwigo", tag::dhwigo}, - {"tnc", tag::tnc}, - {"ntc", tag::ntc}, - {"ldnc", tag::ldnc}, - {"ldigo", tag::ldigo}, - {"ldgoi", tag::ldgoi}, - {"ldio", tag::ldio}, - {"ldoi", tag::ldoi}, - {"ldgo", tag::ldgo}, - {"nCdhw16c", tag::nCdhw16c}, - {"nCdhw4c", tag::nCdhw4c}, - {"nCdhw8c", tag::nCdhw8c}, - {"nChw16c", tag::nChw16c}, - {"nChw4c", tag::nChw4c}, - {"nChw8c", tag::nChw8c}, - {"nCw16c", tag::nCw16c}, - {"nCw4c", tag::nCw4c}, - {"nCw8c", tag::nCw8c}, - {"NCw16n16c", tag::NCw16n16c}, - {"NChw16n16c", tag::NChw16n16c}, - {"NCdhw16n16c", tag::NCdhw16n16c}, - {"NCdhw32n32c", tag::NCdhw32n32c}, - {"NChw32n32c", tag::NChw32n32c}, - {"IOhw16i16o", tag::IOhw16i16o}, - {"OI16i16o", tag::OI16i16o}, - {"OI16i32o", tag::OI16i32o}, - {"OI16i64o", tag::OI16i64o}, - {"OI8i16o2i", tag::OI8i16o2i}, - {"OI8i32o2i", tag::OI8i32o2i}, - {"OI8i64o2i", tag::OI8i64o2i}, - {"OI4i16o4i", tag::OI4i16o4i}, - {"OI4i32o4i", tag::OI4i32o4i}, - {"OI4i64o4i", tag::OI4i64o4i}, - {"Ohwi32o", tag::Ohwi32o}, - {"IOdhw16i16o", tag::IOdhw16i16o}, - {"gIOhw16i16o", tag::gIOhw16i16o}, - {"gOhwi32o", tag::gOhwi32o}, - {"Goidhw16g", tag::Goidhw16g}, - {"IOw16o16i", tag::IOw16o16i}, - {"OIw16i16o", tag::OIw16i16o}, - {"OIw16i32o", tag::OIw16i32o}, - {"OIw16i64o", tag::OIw16i64o}, - {"IOw16i16o", tag::IOw16i16o}, - {"gIOw16i16o", tag::gIOw16i16o}, - {"OIw16o16i", tag::OIw16o16i}, - {"Oiw16o", tag::Oiw16o}, - {"OIw4i16o4i", tag::OIw4i16o4i}, - {"OIw4i32o4i", tag::OIw4i32o4i}, - {"OIw4i64o4i", tag::OIw4i64o4i}, - {"OIw2i8o4i", tag::OIw2i8o4i}, - {"OIw4i4o", tag::OIw4i4o}, - {"OIw4o4i", tag::OIw4o4i}, - {"Oiw4o", tag::Oiw4o}, - {"OIw8i16o2i", tag::OIw8i16o2i}, - {"OIw8i32o2i", tag::OIw8i32o2i}, - {"OIw8i64o2i", tag::OIw8i64o2i}, - {"OIw8i8o", tag::OIw8i8o}, - {"OIw8o16i2o", tag::OIw8o16i2o}, - {"OIw8o8i", tag::OIw8o8i}, - {"OIw8o4i", tag::OIw8o4i}, - {"OIw16i16o4i", tag::OIw16i16o4i}, - {"OIw16i32o4i", tag::OIw16i32o4i}, - {"OIw16i48o4i", tag::OIw16i48o4i}, - {"OIw16i64o4i", tag::OIw16i64o4i}, - {"OIw16i16o2i", tag::OIw16i16o2i}, - {"OIw16i32o2i", tag::OIw16i32o2i}, - {"OIw16i48o2i", tag::OIw16i48o2i}, - {"OIw16i64o2i", tag::OIw16i64o2i}, - {"OIw16o16i2o", tag::OIw16o16i2o}, - {"Owi16o", tag::Owi16o}, - {"OwI16o2i", tag::OwI16o2i}, - {"Owi4o", tag::Owi4o}, - {"Owi8o", tag::Owi8o}, - {"IOhw16o16i", tag::IOhw16o16i}, - {"Ohwi16o", tag::Ohwi16o}, - {"OhwI16o2i", tag::OhwI16o2i}, - {"Ohwi4o", tag::Ohwi4o}, - {"Ohwi8o", tag::Ohwi8o}, - {"OIhw16i16o", tag::OIhw16i16o}, - {"OIhw16i32o", tag::OIhw16i32o}, - {"OIhw16i64o", tag::OIhw16i64o}, - {"OIhw16o16i", tag::OIhw16o16i}, - {"Oihw16o", tag::Oihw16o}, - {"OIhw4i16o4i", tag::OIhw4i16o4i}, - {"OIhw4i32o4i", tag::OIhw4i32o4i}, - {"OIhw4i64o4i", tag::OIhw4i64o4i}, - {"OIhw4i4o", tag::OIhw4i4o}, - {"OIhw4o4i", tag::OIhw4o4i}, - {"Oihw4o", tag::Oihw4o}, - {"OIhw8i16o2i", tag::OIhw8i16o2i}, - {"OIhw8i32o2i", tag::OIhw8i32o2i}, - {"OIhw8i64o2i", tag::OIhw8i64o2i}, - {"OIhw8i8o", tag::OIhw8i8o}, - {"OIhw8o16i2o", tag::OIhw8o16i2o}, - {"OIhw8o8i", tag::OIhw8o8i}, - {"OIhw8o4i", tag::OIhw8o4i}, - {"OIhw2i8o4i", tag::OIhw2i8o4i}, - {"IOdhw16o16i", tag::IOdhw16o16i}, - {"Odhwi16o", tag::Odhwi16o}, - {"OdhwI16o2i", tag::OdhwI16o2i}, - {"Odhwi4o", tag::Odhwi4o}, - {"Odhwi8o", tag::Odhwi8o}, - {"OIdhw16i16o", tag::OIdhw16i16o}, - {"OIdhw16i32o", tag::OIdhw16i32o}, - {"OIdhw16i64o", tag::OIdhw16i64o}, - {"OIdhw16o16i", tag::OIdhw16o16i}, - {"Oidhw16o", tag::Oidhw16o}, - {"OIdhw4i4o", tag::OIdhw4i4o}, - {"OIdhw4o4i", tag::OIdhw4o4i}, - {"Oidhw4o", tag::Oidhw4o}, - {"OIdhw8i16o2i", tag::OIdhw8i16o2i}, - {"OIdhw8i32o2i", tag::OIdhw8i32o2i}, - {"OIdhw8i64o2i", tag::OIdhw8i64o2i}, - {"OIdhw4i16o4i", tag::OIdhw4i16o4i}, - {"OIdhw16i16o4i", tag::OIdhw16i16o4i}, - {"OIdhw16i32o4i", tag::OIdhw16i32o4i}, - {"OIdhw16i48o4i", tag::OIdhw16i48o4i}, - {"OIdhw16i64o4i", tag::OIdhw16i64o4i}, - {"OIdhw16i16o2i", tag::OIdhw16i16o2i}, - {"OIdhw16i32o2i", tag::OIdhw16i32o2i}, - {"OIdhw16i48o2i", tag::OIdhw16i48o2i}, - {"OIdhw16i64o2i", tag::OIdhw16i64o2i}, - {"OIdhw4i32o4i", tag::OIdhw4i32o4i}, - {"OIdhw4i64o4i", tag::OIdhw4i64o4i}, - {"OIdhw2i8o4i", tag::OIdhw2i8o4i}, - {"OIdhw8i8o", tag::OIdhw8i8o}, - {"OIdhw8o8i", tag::OIdhw8o8i}, - {"OIdhw8o4i", tag::OIdhw8o4i}, - {"gIOw16o16i", tag::gIOw16o16i}, - {"gOIw16i16o", tag::gOIw16i16o}, - {"gOIw16o16i", tag::gOIw16o16i}, - {"gOiw16o", tag::gOiw16o}, - {"gOIw4i16o4i", tag::gOIw4i16o4i}, - {"gOIw2i8o4i", tag::gOIw2i8o4i}, - {"gOIw4i4o", tag::gOIw4i4o}, - {"gOIw4o4i", tag::gOIw4o4i}, - {"gOiw4o", tag::gOiw4o}, - {"gOIw8i16o2i", tag::gOIw8i16o2i}, - {"gOIw8i8o", tag::gOIw8i8o}, - {"gOIw8o16i2o", tag::gOIw8o16i2o}, - {"gOIw8o8i", tag::gOIw8o8i}, - {"gOIw8o4i", tag::gOIw8o4i}, - {"gOIw16i16o4i", tag::gOIw16i16o4i}, - {"gOIw16i16o2i", tag::gOIw16i16o2i}, - {"gOIw16o16i2o", tag::gOIw16o16i2o}, - {"gOwi16o", tag::gOwi16o}, - {"gOwI16o2i", tag::gOwI16o2i}, - {"gOwi4o", tag::gOwi4o}, - {"gOwi8o", tag::gOwi8o}, - {"Goiw8g", tag::Goiw8g}, - {"Goiw16g", tag::Goiw16g}, - {"gIOhw16o16i", tag::gIOhw16o16i}, - {"gOhwi16o", tag::gOhwi16o}, - {"gOhwI16o2i", tag::gOhwI16o2i}, - {"gOhwi4o", tag::gOhwi4o}, - {"gOhwi8o", tag::gOhwi8o}, - {"Goihw16g", tag::Goihw16g}, - {"gOIhw16i16o", tag::gOIhw16i16o}, - {"gOIhw16o16i", tag::gOIhw16o16i}, - {"gOihw16o", tag::gOihw16o}, - {"gOIhw4i16o4i", tag::gOIhw4i16o4i}, - {"gOIhw2i8o4i", tag::gOIhw2i8o4i}, - {"gOIhw4i4o", tag::gOIhw4i4o}, - {"gOIhw4o4i", tag::gOIhw4o4i}, - {"gOihw4o", tag::gOihw4o}, - {"Goihw8g", tag::Goihw8g}, - {"gOIhw8i16o2i", tag::gOIhw8i16o2i}, - {"gOIhw8i8o", tag::gOIhw8i8o}, - {"gOIhw8o16i2o", tag::gOIhw8o16i2o}, - {"OIw4o8i8o4i", tag::OIw4o8i8o4i}, - {"OIdhw4o8i8o4i", tag::OIdhw4o8i8o4i}, - {"OIhw4o8i8o4i", tag::OIhw4o8i8o4i}, - {"OIhw2o8i8o2i", tag::OIhw2o8i8o2i}, - {"gOIw4o8i8o4i", tag::gOIw4o8i8o4i}, - {"gOIdhw4o8i8o4i", tag::gOIdhw4o8i8o4i}, - {"gOIhw4o8i8o4i", tag::gOIhw4o8i8o4i}, - {"gOIhw2o8i8o2i", tag::gOIhw2o8i8o2i}, - {"OIhw16i16o4i", tag::OIhw16i16o4i}, - {"OIhw16i32o4i", tag::OIhw16i32o4i}, - {"OIhw16i48o4i", tag::OIhw16i48o4i}, - {"OIhw16i64o4i", tag::OIhw16i64o4i}, - {"OIhw16i16o2i", tag::OIhw16i16o2i}, - {"OIhw16i32o2i", tag::OIhw16i32o2i}, - {"OIhw16i48o2i", tag::OIhw16i48o2i}, - {"OIhw16i64o2i", tag::OIhw16i64o2i}, - {"OIhw16o16i2o", tag::OIhw16o16i2o}, - {"gOIhw16i16o4i", tag::gOIhw16i16o4i}, - {"gOIhw16i16o2i", tag::gOIhw16i16o2i}, - {"gOIhw16o16i2o", tag::gOIhw16o16i2o}, - {"gOIhw8o8i", tag::gOIhw8o8i}, - {"gOIhw8o4i", tag::gOIhw8o4i}, - {"gIOdhw16i16o", tag::gIOdhw16i16o}, - {"gIOdhw16o16i", tag::gIOdhw16o16i}, - {"gOdhwi16o", tag::gOdhwi16o}, - {"gOdhwI16o2i", tag::gOdhwI16o2i}, - {"gOdhwi4o", tag::gOdhwi4o}, - {"gOdhwi8o", tag::gOdhwi8o}, - {"gOIdhw16i16o", tag::gOIdhw16i16o}, - {"gOIdhw16o16i", tag::gOIdhw16o16i}, - {"gOidhw16o", tag::gOidhw16o}, - {"gOIdhw4i4o", tag::gOIdhw4i4o}, - {"gOIdhw4o4i", tag::gOIdhw4o4i}, - {"gOidhw4o", tag::gOidhw4o}, - {"gOIdhw8i16o2i", tag::gOIdhw8i16o2i}, - {"gOIdhw4i16o4i", tag::gOIdhw4i16o4i}, - {"gOIdhw16i16o4i", tag::gOIdhw16i16o4i}, - {"gOIdhw16i16o2i", tag::gOIdhw16i16o2i}, - {"gOIdhw2i8o4i", tag::gOIdhw2i8o4i}, - {"gOIdhw8i8o", tag::gOIdhw8i8o}, - {"gOIdhw8o8i", tag::gOIdhw8o8i}, - {"gOIdhw8o4i", tag::gOIdhw8o4i}, - {"gOIw2i4o2i", tag::gOIw2i4o2i}, - {"gOIhw2i4o2i", tag::gOIhw2i4o2i}, - {"gOIdhw2i4o2i", tag::gOIdhw2i4o2i}, - {"gOIw2o4i2o", tag::gOIw2o4i2o}, - {"gOIhw2o4i2o", tag::gOIhw2o4i2o}, - {"gOIdhw2o4i2o", tag::gOIdhw2o4i2o}, - {"gOIw4i8o2i", tag::gOIw4i8o2i}, - {"gOIhw4i8o2i", tag::gOIhw4i8o2i}, - {"gOIdhw4i8o2i", tag::gOIdhw4i8o2i}, - {"gOIw4o8i2o", tag::gOIw4o8i2o}, - {"gOIhw4o8i2o", tag::gOIhw4o8i2o}, - {"gOIdhw4o8i2o", tag::gOIdhw4o8i2o}, - {"ldOi32o", tag::ldOi32o}, - {"ldOI32o4i", tag::ldOI32o4i}, - {"ldgOi32o", tag::ldgOi32o}, - {"ldgOI32o2i", tag::ldgOI32o2i}, - {"ldgOI32o4i", tag::ldgOI32o4i}, - {"OwI16o4i", tag::OwI16o4i}, - {"OhwI16o4i", tag::OhwI16o4i}, - {"gOwI16o4i", tag::gOwI16o4i}, - {"gOhwI16o4i", tag::gOhwI16o4i}, - {"OdhwI16o4i", tag::OdhwI16o4i}, - {"gOdhwI16o4i", tag::gOdhwI16o4i}, - {"Owi32o", tag::Owi32o}, - {"OwI32o2i", tag::OwI32o2i}, - {"OwI32o4i", tag::OwI32o4i}, - {"Owi48o", tag::Owi48o}, - {"OwI48o2i", tag::OwI48o2i}, - {"OwI48o4i", tag::OwI48o4i}, - {"Owi64o", tag::Owi64o}, - {"OwI64o2i", tag::OwI64o2i}, - {"OwI64o4i", tag::OwI64o4i}, - {"wIo2i", tag::wIo2i}, - {"wIo4i", tag::wIo4i}, - {"gOwi32o", tag::gOwi32o}, - {"gOwI32o2i", tag::gOwI32o2i}, - {"gOwI32o4i", tag::gOwI32o4i}, - {"gOwi48o", tag::gOwi48o}, - {"gOwI48o2i", tag::gOwI48o2i}, - {"gOwI48o4i", tag::gOwI48o4i}, - {"gOwi64o", tag::gOwi64o}, - {"gOwI64o2i", tag::gOwI64o2i}, - {"gOwI64o4i", tag::gOwI64o4i}, - {"gwio", tag::gwio}, - {"gwIo2i", tag::gwIo2i}, - {"gwIo4i", tag::gwIo4i}, - {"OhwI32o", tag::OhwI32o}, - {"OhwI32o2i", tag::OhwI32o2i}, - {"OhwI32o4i", tag::OhwI32o4i}, - {"Ohwi48o", tag::Ohwi48o}, - {"OhwI48o2i", tag::OhwI48o2i}, - {"OhwI48o4i", tag::OhwI48o4i}, - {"Ohwi64o", tag::Ohwi64o}, - {"OhwI64o2i", tag::OhwI64o2i}, - {"OhwI64o4i", tag::OhwI64o4i}, - {"hwIo2i", tag::hwIo2i}, - {"hwIo4i", tag::hwIo4i}, - {"gOhwI32o", tag::gOhwI32o}, - {"gOhwI32o2i", tag::gOhwI32o2i}, - {"gOhwI32o4i", tag::gOhwI32o4i}, - {"gOhwi48o", tag::gOhwi48o}, - {"gOhwI48o2i", tag::gOhwI48o2i}, - {"gOhwI48o4i", tag::gOhwI48o4i}, - {"gOhwi64o", tag::gOhwi64o}, - {"gOhwI64o2i", tag::gOhwI64o2i}, - {"gOhwI64o4i", tag::gOhwI64o4i}, - {"ghwio", tag::ghwio}, - {"ghwIo2i", tag::ghwIo2i}, - {"ghwIo4i", tag::ghwIo4i}, - {"Odhwi32o", tag::Odhwi32o}, - {"OdhwI32o2i", tag::OdhwI32o2i}, - {"OdhwI32o4i", tag::OdhwI32o4i}, - {"Odhwi48o", tag::Odhwi48o}, - {"OdhwI48o2i", tag::OdhwI48o2i}, - {"OdhwI48o4i", tag::OdhwI48o4i}, - {"Odhwi64o", tag::Odhwi64o}, - {"OdhwI64o2i", tag::OdhwI64o2i}, - {"OdhwI64o4i", tag::OdhwI64o4i}, - {"dhwIo2i", tag::dhwIo2i}, - {"dhwIo4i", tag::dhwIo4i}, - {"gOdhwi32o", tag::gOdhwi32o}, - {"gOdhwI32o2i", tag::gOdhwI32o2i}, - {"gOdhwI32o4i", tag::gOdhwI32o4i}, - {"gOdhwi48o", tag::gOdhwi48o}, - {"gOdhwI48o2i", tag::gOdhwI48o2i}, - {"gOdhwI48o4i", tag::gOdhwI48o4i}, - {"gOdhwi64o", tag::gOdhwi64o}, - {"gOdhwI64o2i", tag::gOdhwI64o2i}, - {"gOdhwI64o4i", tag::gOdhwI64o4i}, - {"gdhwio", tag::gdhwio}, - {"gdhwIo2i", tag::gdhwIo2i}, - {"gdhwIo4i", tag::gdhwIo4i}, - {"ldIo32i", tag::ldIo32i}, - {"ldgIo32i", tag::ldgIo32i}, - {"ldgIO32i2o", tag::ldgIO32i2o}, - {"nCdhw32c", tag::nCdhw32c}, - {"nChw32c", tag::nChw32c}, - {"nCw32c", tag::nCw32c}, - {"NCw32n16c", tag::NCw32n16c}, - {"NChw32n16c", tag::NChw32n16c}, - {"NCdhw32n16c", tag::NCdhw32n16c}, - {"NCw32n32c", tag::NCw32n32c}, - {"OI16i16o4i", tag::OI16i16o4i}, - {"IOw8o16i2o", tag::IOw8o16i2o}, - {"IOhw8o16i2o", tag::IOhw8o16i2o}, - {"Owhi16o", tag::Owhi16o}, - {"OIdhw8o16i2o", tag::OIdhw8o16i2o}, - {"IOdhw8o16i2o", tag::IOdhw8o16i2o}, - {"Goiw4g", tag::Goiw4g}, - {"gIOw8o16i2o", tag::gIOw8o16i2o}, - {"Goiw32g", tag::Goiw32g}, - {"Goihw4g", tag::Goihw4g}, - {"gIOhw8o16i2o", tag::gIOhw8o16i2o}, - {"Goihw32g", tag::Goihw32g}, - {"gOwhi16o", tag::gOwhi16o}, - {"IOw4i8o8i4o", tag::IOw4i8o8i4o}, - {"IOhw4i8o8i4o", tag::IOhw4i8o8i4o}, - {"IOdhw4i8o8i4o", tag::IOdhw4i8o8i4o}, - {"gIOw4i8o8i4o", tag::gIOw4i8o8i4o}, - {"gIOhw4i8o8i4o", tag::gIOhw4i8o8i4o}, - {"gIOdhw4i8o8i4o", tag::gIOdhw4i8o8i4o}, - {"gOIdhw8o16i2o", tag::gOIdhw8o16i2o}, - {"gIOdhw8o16i2o", tag::gIOdhw8o16i2o}, - {"Goidhw32g", tag::Goidhw32g}, - {"OI16i32o4i", tag::OI16i32o4i}, - {"OI16i48o4i", tag::OI16i48o4i}, - {"OI16i64o4i", tag::OI16i64o4i}, - {"OI16i16o2i", tag::OI16i16o2i}, - {"OI16i32o2i", tag::OI16i32o2i}, - {"OI16i48o2i", tag::OI16i48o2i}, - {"OI16i64o2i", tag::OI16i64o2i}, - {"OwI16i16o2i", tag::OwI16i16o2i}, - {"gOwI16i16o2i", tag::gOwI16i16o2i}, - {"OhwI16i16o2i", tag::OhwI16i16o2i}, - {"gOhwI16i16o2i", tag::gOhwI16i16o2i}, - {"OdhwI16i16o2i", tag::OdhwI16i16o2i}, - {"gOdhwI16i16o2i", tag::gOdhwI16i16o2i}, - {"OwI16i16o4i", tag::OwI16i16o4i}, - {"gOwI16i16o4i", tag::gOwI16i16o4i}, - {"OhwI16i16o4i", tag::OhwI16i16o4i}, - {"gOhwI16i16o4i", tag::gOhwI16i16o4i}, - {"OdhwI16i16o4i", tag::OdhwI16i16o4i}, - {"gOdhwI16i16o4i", tag::gOdhwI16i16o4i}, - {"OwI16i32o2i", tag::OwI16i32o2i}, - {"OwI16i32o4i", tag::OwI16i32o4i}, - {"OwI16i48o2i", tag::OwI16i48o2i}, - {"OwI16i48o4i", tag::OwI16i48o4i}, - {"OwI16i64o2i", tag::OwI16i64o2i}, - {"OwI16i64o4i", tag::OwI16i64o4i}, - {"gOwI16i32o2i", tag::gOwI16i32o2i}, - {"gOwI16i32o4i", tag::gOwI16i32o4i}, - {"gOwI16i48o2i", tag::gOwI16i48o2i}, - {"gOwI16i48o4i", tag::gOwI16i48o4i}, - {"gOwI16i64o2i", tag::gOwI16i64o2i}, - {"gOwI16i64o4i", tag::gOwI16i64o4i}, - {"OhwI16i32o2i", tag::OhwI16i32o2i}, - {"OhwI16i32o4i", tag::OhwI16i32o4i}, - {"OhwI16i48o2i", tag::OhwI16i48o2i}, - {"OhwI16i48o4i", tag::OhwI16i48o4i}, - {"OhwI16i64o2i", tag::OhwI16i64o2i}, - {"OhwI16i64o4i", tag::OhwI16i64o4i}, - {"gOhwI16i32o2i", tag::gOhwI16i32o2i}, - {"gOhwI16i32o4i", tag::gOhwI16i32o4i}, - {"gOhwI16i48o2i", tag::gOhwI16i48o2i}, - {"gOhwI16i48o4i", tag::gOhwI16i48o4i}, - {"gOhwI16i64o2i", tag::gOhwI16i64o2i}, - {"gOhwI16i64o4i", tag::gOhwI16i64o4i}, - {"OdhwI16i32o2i", tag::OdhwI16i32o2i}, - {"OdhwI16i32o4i", tag::OdhwI16i32o4i}, - {"OdhwI16i48o2i", tag::OdhwI16i48o2i}, - {"OdhwI16i48o4i", tag::OdhwI16i48o4i}, - {"OdhwI16i64o2i", tag::OdhwI16i64o2i}, - {"OdhwI16i64o4i", tag::OdhwI16i64o4i}, - {"gOdhwI16i32o2i", tag::gOdhwI16i32o2i}, - {"gOdhwI16i32o4i", tag::gOdhwI16i32o4i}, - {"gOdhwI16i48o2i", tag::gOdhwI16i48o2i}, - {"gOdhwI16i48o4i", tag::gOdhwI16i48o4i}, - {"gOdhwI16i64o2i", tag::gOdhwI16i64o2i}, - {"gOdhwI16i64o4i", tag::gOdhwI16i64o4i}, - {"hwioG16g", tag::hwioG16g}, - {"NCdhw40n32c", tag::NCdhw40n32c}, - {"NChw40n32c", tag::NChw40n32c}, - {"NCw40n32c", tag::NCw40n32c}, - {"OIdhw4o8i8o2i", tag::OIdhw4o8i8o2i}, - {"OIhw4o8i8o2i", tag::OIhw4o8i8o2i}, - {"OIw4o8i8o2i", tag::OIw4o8i8o2i}, - {"gOIdhw4o8i8o2i", tag::gOIdhw4o8i8o2i}, - {"gOIhw4o8i8o2i", tag::gOIhw4o8i8o2i}, - {"gOIw4o8i8o2i", tag::gOIw4o8i8o2i}, - {"IOdhw4i8o8i2o", tag::IOdhw4i8o8i2o}, - {"IOhw4i8o8i2o", tag::IOhw4i8o8i2o}, - {"IOw4i8o8i2o", tag::IOw4i8o8i2o}, - {"gIOdhw4i8o8i2o", tag::gIOdhw4i8o8i2o}, - {"gIOhw4i8o8i2o", tag::gIOhw4i8o8i2o}, - {"gIOw4i8o8i2o", tag::gIOw4i8o8i2o}, - {"NCdhw40n16c", tag::NCdhw40n16c}, - {"NCw40n16c", tag::NCw40n16c}, - {"NChw40n16c", tag::NChw40n16c}, - {"NCw2c32n8c", tag::NCw2c32n8c}, - {"NChw2c32n8c", tag::NChw2c32n8c}, - {"NCdhw2c32n8c", tag::NCdhw2c32n8c}, - {"OIw2i8o16i4o", tag::OIw2i8o16i4o}, - {"OIhw2i8o16i4o", tag::OIhw2i8o16i4o}, - {"OIdhw2i8o16i4o", tag::OIdhw2i8o16i4o}, - {"OIw2o8i16o4i", tag::OIw2o8i16o4i}, - {"OIw2o8i16o2i", tag::OIw2o8i16o2i}, - {"IOw2i8o16i4o", tag::IOw2i8o16i4o}, - {"IOw2i8o16i2o", tag::IOw2i8o16i2o}, - {"OIhw2o8i16o4i", tag::OIhw2o8i16o4i}, - {"OIhw2o8i16o2i", tag::OIhw2o8i16o2i}, - {"IOhw2i8o16i4o", tag::IOhw2i8o16i4o}, - {"IOhw2i8o16i2o", tag::IOhw2i8o16i2o}, - {"OIdhw2o8i16o4i", tag::OIdhw2o8i16o4i}, - {"OIdhw2o8i16o2i", tag::OIdhw2o8i16o2i}, - {"IOdhw2i8o16i4o", tag::IOdhw2i8o16i4o}, - {"IOdhw2i8o16i2o", tag::IOdhw2i8o16i2o}, - {"gOIw2o8i16o2i", tag::gOIw2o8i16o2i}, - {"gIOw2i8o16i2o", tag::gIOw2i8o16i2o}, - {"gIOhw2i8o16i2o", tag::gIOhw2i8o16i2o}, - {"gIOdhw2i8o16i2o", tag::gIOdhw2i8o16i2o}, - {"gOIhw2o8i16o2i", tag::gOIhw2o8i16o2i}, - {"gOIdhw2o8i16o2i", tag::gOIdhw2o8i16o2i}, - {"gOIw2o8i16o4i", tag::gOIw2o8i16o4i}, - {"gOIhw2o8i16o4i", tag::gOIhw2o8i16o4i}}; - std::string key = ""; - for (const auto& c : layout) { - if (std::isalpha(c, std::locale("C"))) { - char lower_c = std::tolower(c); - if (std::isupper(c) && (layout.find(lower_c) != std::string::npos)) { - key.push_back(c); - } else { - key.push_back(lower_c); - } - } else if (std::isdigit(c)) { - key.push_back(c); - } else { - LOG(FATAL) << "invalid char '" << c << "' in " << layout << std::endl; - } - } - if (str2tag.count(key) == 0) { - LOG(WARNING) << "convert unregistered layout '" << key << "' to tag::any"; - return tag::any; + /* Override GetFunction to reimplement Run method */ + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override { + if (this->symbol_name_ == name) { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + ICHECK(this->initialized_) << "The module has not been initialized"; + + ICHECK_EQ(args.size(), input_var_eid_.size() + outputs_.size()) + << "Found mismatch in the number of provided data entries and required."; + + Run(args); + }); } else { - return str2tag.at(key); + return JSONRuntimeBase::GetFunction(name, sptr_to_self); + } + } + + /* Same as makeInitDataProvider but in case of InputOutput return real DLTensor */ + TensorRegistry::DLTensorProvider makeIODataProvider(const TVMArgs& args) const { + auto extract_dl_tensor = [](const TVMArgValue& val) -> const DLTensor* { + ICHECK(val.type_code() == kTVMNDArrayHandle || val.type_code() == kTVMDLTensorHandle) + << "Expect NDArray or DLTensor"; + return val.IsObjectRef() ? val.operator NDArray().operator->() + : val.operator DLTensor*(); + }; + + std::map io_map; // eid to dl tensor map + for (size_t i = 0; i < run_arg_eid_.size(); i++) { + io_map[run_arg_eid_[i]] = extract_dl_tensor(args[i]); } + + // lambda with captured IO data handlers + return [io_map](uint32_t eid) -> const DLTensor* { return io_map.at(eid); }; } - std::map elt_name2algo{ + private: + const std::map elt_name2algo{ {"abs", dnnl::algorithm::eltwise_abs}, {"exp", dnnl::algorithm::eltwise_exp}, {"log", dnnl::algorithm::eltwise_log}, @@ -626,64 +161,14 @@ class DNNLJSONRuntime : public JSONRuntimeBase { return std::regex_match(op_name, bias_add_pat) ? true : false; } - dnnl::memory::dims TransDims2Plain(dnnl::memory::dims input_dims, std::string layout) { - std::vector axis = { - 'N', 'C', 'O', 'I', 'D', 'H', 'W', - }; - dnnl::memory::dims out_dims; - std::string::iterator t = layout.begin(); - // Remove numbers in layout string to match the size of input_dims - while (t != layout.end()) { - if (*t >= '0' && *t <= '9') { - layout.erase(t); - } else { - t++; - } - } - // Push the correct shapes of each axis into the output_dims - for (auto a : axis) { - if (layout.find(a) != std::string::npos) { - dnnl::memory::dim shape = input_dims[layout.find(a)]; - char lower_a = std::tolower(a); - for (size_t i = 0; i < layout.size(); ++i) { - if (lower_a == layout[i]) { - shape *= input_dims[i]; - } - } - out_dims.push_back(shape); - } - } - // Multiply O and I with G, respectively - if (layout.find("G") != std::string::npos) { - dnnl::memory::dim G = 1; - if (layout.find("g") != std::string::npos) { - G = input_dims[layout.find("g")] * input_dims[layout.find("G")]; - } else { - G = input_dims[layout.find("G")]; - } - out_dims[0] *= G; - out_dims[1] *= G; - } - return out_dims; - } - - dnnl::memory::dims TransformStr2Dims(std::vector strs, bool dilates = false) { - dnnl::memory::dims out_dims; - if (dilates) { - std::transform(strs.begin(), strs.end(), std::back_inserter(out_dims), - [](const std::string& str) { return std::stoi(str) - 1; }); - } else { - std::transform(strs.begin(), strs.end(), std::back_inserter(out_dims), - [](const std::string& str) { return std::stoi(str); }); - } - return out_dims; - } - // Build up the engine based on the input graph. void BuildEngine() { engine_ = dnnl::engine(dnnl::engine::kind::cpu, 0); stream_ = dnnl::stream(engine_); + std::set io_eid_set(run_arg_eid_.begin(), run_arg_eid_.end()); + tensor_registry_ = TensorRegistry(engine_, io_eid_set); + std::regex conv_pat(".*conv[1-3]d.*"); std::regex deconv_pat(".*deconv[1-3]d.*"); std::regex conv_transpose_pat(".*conv[1-3]d_transpose.*"); @@ -725,562 +210,471 @@ class DNNLJSONRuntime : public JSONRuntimeBase { } } - // Bind a JSON graph node entry to a DNNL memory. - dnnl::memory BindDNNLMemory(const JSONGraphNodeEntry& entry, dnnl::memory::desc mem_desc, - size_t offset = 0) { - auto eid = EntryID(entry); - if (entry_out_mem_.count(eid) == 0) { - return BindDNNLMemory(entry, dnnl::memory(mem_desc, engine_), offset); - } - return entry_out_mem_[eid].first; - } - - // Bind a JSON graph node entry to a given DNNL memory. - dnnl::memory BindDNNLMemory(const JSONGraphNodeEntry& entry, dnnl::memory mem, - size_t offset = 0) { - auto eid = EntryID(entry); - // Since the DNNL memory has been created before calling this function, we assume the entry - // has not yet been bound to the other DNNL memory; otherwise it may have memory leak. - ICHECK_EQ(entry_out_mem_.count(eid), 0); - - entry_out_mem_[eid] = {mem, offset}; - return entry_out_mem_[eid].first; - } - void Convolution(const size_t& nid) { auto node = nodes_[nid]; auto op_name = node.GetOpName(); dnnl::primitive_attr attr; + attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); bool has_bias = ParsingOpName(op_name, attr); // Setup attributes. - auto data_entry = node.GetInputs()[0]; - auto weight_entry = node.GetInputs()[1]; - JSONGraphNodeEntry out_entry(nid, 0); - dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_]; - dnnl::memory::dims out_shape = nodes_[out_entry.id_].GetOpShape()[out_entry.index_]; - dnnl::memory::dim channels = - node.GetAttr>("channels")[0] != "" - ? std::stoi(node.GetAttr>("channels")[0]) - : out_shape[1]; - std::vector str_strides = node.GetAttr>("strides"); - std::vector str_dilates = node.GetAttr>("dilation"); - std::vector str_padding = node.GetAttr>("padding"); - std::vector str_padding_l(str_padding.begin(), - str_padding.begin() + str_padding.size() / 2); - std::vector str_padding_r(str_padding.end() - str_padding.size() / 2, - str_padding.end()); - dnnl::memory::dim groups = std::stoi(node.GetAttr>("groups")[0]); - std::string data_layout = node.GetAttr>("data_layout")[0]; - std::string kernel_layout = node.GetAttr>("kernel_layout")[0]; - - // Memory shapes. - dnnl::memory::dims src_dims = TransDims2Plain(input_shape, data_layout); - dnnl::memory::dims weights_dims_ = TransDims2Plain(weight_shape, kernel_layout); - dnnl::memory::dims bias_dims = {channels}; - dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides); - dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, true); - dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l); - dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r); - dnnl::memory::dims dst_dims = src_dims; - dst_dims[1] = channels; - weights_dims_[0] = channels; - weights_dims_[1] = src_dims[1]; - for (size_t i = 2; i < src_dims.size(); i++) { - dnnl::memory::dim K = weights_dims_[i]; - dnnl::memory::dim S = strides_dims[i - 2]; - dnnl::memory::dim D = dilates_dims[i - 2]; - dnnl::memory::dim PL = padding_dims_l[i - 2]; - dnnl::memory::dim PR = padding_dims_r[i - 2]; - dnnl::memory::dim DK = 1 + (K - 1) * (D + 1); - dst_dims[i] = (src_dims[i] - DK + PL + PR) / S + 1; + auto src_tr = GetInput(nid, 0); + auto wgh_tr = GetInput(nid, 1); + auto dst_tr = GetOutput(nid, 0); + auto bias_tr = has_bias ? GetInput(nid, 2) : GetInput(nid, -1); + auto strides = GetNodeAttr>(node, "strides"); + auto dilates = GetNodeAttr>(node, "dilation"); + auto padding = GetNodeAttr>(node, "padding"); + std::vector padding_l(padding.begin(), padding.begin() + padding.size() / 2); + std::vector padding_r(padding.begin() + padding.size() / 2, padding.end()); + auto groups = GetNodeAttr(node, "groups"); + auto src_layout = GetNodeAttr(node, "data_layout"); + auto dst_layout = GetNodeAttr(node, "out_layout"); + auto wgh_layout = GetNodeAttr(node, "kernel_layout"); + + // dst_layout == "" means to use data_layout + if (dst_layout.empty()) dst_layout = src_layout; + + // Minus one for DNNL representation. No dilation for DNNL is 0, for relay is 1. + for (auto& d : dilates) d--; + + // Take into account provided layout strings + src_tr = src_tr.TreatAs(src_layout); + dst_tr = dst_tr.TreatAs(dst_layout); + wgh_tr = wgh_tr.TreatAs(wgh_layout); + + // Should support G mixed with O. Like { G*O, I, H, W } + // Use { G, O, I, H, W } weight format even if groups == 1 + if (wgh_layout.find("G") == std::string::npos) { + auto w_dims = wgh_tr.dims(); + w_dims[0] /= groups; + w_dims.insert(w_dims.begin(), groups); + wgh_tr = wgh_tr.Reshape(w_dims); } - dnnl::memory::dims weights_dims = weights_dims_; - if (groups > 1) { - weights_dims = {groups, channels / groups, src_dims[1] / groups}; - weights_dims.insert(weights_dims.end(), weights_dims_.begin() + 2, weights_dims_.end()); - if (kernel_layout == "OIHW") { - kernel_layout.insert(0, "G"); - } + // Assumption that bias is correct and can be squeezed to 1D + bias_tr = bias_tr.Reshape({dst_tr.dims()[1]}); + + // TODO(@apeskov): This is WA. In case of padded blocked tensor format we do not know original + // shapes. Example tensor {1, 10, 224, 224} with layout "NCNH8c" will lead to tensor + // {1, 2, 224, 224, 8}. Identically as for shapes {1, 11, 224, 224} or {1, 15, 224, 224}. + // + // Let's try to compensate it for weight tensor. Weight IC should match with source IC. + // Example src: [1, 3, 224, 224] with layout NCHW + // wgh: [16, 3, 3, 3] with layout OIHW2i8o -> [2, 2, 3, 3, 2, 8] + if (wgh_tr.dims()[2] != src_tr.dims()[1] / groups) { + auto wgh_croped_dims = wgh_tr.dims(); + wgh_croped_dims[2] = src_tr.dims()[1]; + auto zero_offset = dnnl::memory::dims(wgh_tr.dims().size(), 0); + wgh_tr = wgh_tr.Crop(wgh_croped_dims, zero_offset); } - // Memory descriptions. - auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]); - auto conv_src_md = dnnl::memory::desc(src_dims, dtype, layout2tag(data_layout)); - auto conv_weights_md = dnnl::memory::desc(weights_dims, dtype, layout2tag(kernel_layout)); - auto conv_bias_md = dnnl::memory::desc(bias_dims, dtype, tag::any); - auto conv_dst_md = dnnl::memory::desc(dst_dims, dtype, tag::any); - // Conv description. - auto conv_desc = - has_bias ? dnnl::convolution_forward::desc( - dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct, - conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, strides_dims, - dilates_dims, padding_dims_l, padding_dims_r) - : dnnl::convolution_forward::desc(dnnl::prop_kind::forward_inference, - dnnl::algorithm::convolution_direct, conv_src_md, - conv_weights_md, conv_dst_md, strides_dims, - dilates_dims, padding_dims_l, padding_dims_r); + auto conv_desc = dnnl::convolution_forward::desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct, + src_tr.LayoutAny().desc(), wgh_tr.LayoutAny().desc(), bias_tr.LayoutAny().desc(), + dst_tr.LayoutAny().desc(), strides, dilates, padding_l, padding_r); // Enable elementwise post-ops. auto conv_prim_desc = dnnl::convolution_forward::primitive_desc(conv_desc, attr, engine_); - // Push to the network. - auto conv = dnnl::convolution_forward(conv_prim_desc); - net_.push_back(conv); - - // Data memory. - auto conv_src_memory = BindDNNLMemory(data_entry, conv_src_md); + src_tr = src_tr.RequestLayout(conv_prim_desc.src_desc()); + wgh_tr = wgh_tr.RequestLayout(conv_prim_desc.weights_desc()); + dst_tr = dst_tr.RequestLayout(conv_prim_desc.dst_desc()); + bias_tr = bias_tr.RequestLayout(conv_prim_desc.bias_desc()); - // Weight memory. - auto conv_weights_memory = BindDNNLMemory(weight_entry, conv_prim_desc.weights_desc()); + auto scratchpad_tr = TensorRequisite::AsIs(conv_prim_desc.scratchpad_desc()); - // Output memory. - auto conv_dst_memory = BindDNNLMemory(out_entry, conv_prim_desc.dst_desc()); - - // Bias memory. - auto conv_bias_memory = dnnl::memory({bias_dims, dtype, tag::x}, engine_); - if (has_bias) { - auto bias_entry = node.GetInputs()[2]; - BindDNNLMemory(bias_entry, conv_bias_memory); - - // Bind memory buffers. - net_args_.push_back({{DNNL_ARG_SRC, conv_src_memory}, - {DNNL_ARG_WEIGHTS, conv_weights_memory}, - {DNNL_ARG_BIAS, conv_bias_memory}, - {DNNL_ARG_DST, conv_dst_memory}}); - } else { - // Bind memory buffers. - net_args_.push_back({{DNNL_ARG_SRC, conv_src_memory}, - {DNNL_ARG_WEIGHTS, conv_weights_memory}, - {DNNL_ARG_DST, conv_dst_memory}}); - } + Submit(dnnl::convolution_forward(conv_prim_desc), {{DNNL_ARG_SRC, src_tr}, + {DNNL_ARG_WEIGHTS, wgh_tr}, + {DNNL_ARG_BIAS, bias_tr}, + {DNNL_ARG_SCRATCHPAD, scratchpad_tr}, + {DNNL_ARG_DST, dst_tr}}); } void Deconvolution(const size_t& nid) { auto node = nodes_[nid]; auto op_name = node.GetOpName(); dnnl::primitive_attr attr; + attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); bool has_bias = ParsingOpName(op_name, attr); // Setup attributes. - auto data_entry = node.GetInputs()[0]; - auto weight_entry = node.GetInputs()[1]; - JSONGraphNodeEntry out_entry(nid, 0); - dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_]; - dnnl::memory::dims out_shape = nodes_[out_entry.id_].GetOpShape()[out_entry.index_]; - dnnl::memory::dim channels = - node.GetAttr>("channels")[0] != "" - ? std::stoi(node.GetAttr>("channels")[0]) - : out_shape[1]; - std::vector str_strides = node.GetAttr>("strides"); - std::vector str_dilates = node.GetAttr>("dilation"); - std::vector str_padding = node.GetAttr>("padding"); - std::vector str_padding_l(str_padding.begin(), - str_padding.begin() + str_padding.size() / 2); - std::vector str_padding_r(str_padding.end() - str_padding.size() / 2, - str_padding.end()); - std::vector str_out_padding = - node.GetAttr>("output_padding"); - dnnl::memory::dim groups = std::stoi(node.GetAttr>("groups")[0]); - std::string data_layout = node.GetAttr>("data_layout")[0]; - std::string kernel_layout = node.GetAttr>("kernel_layout")[0]; - - // Memory shapes. - dnnl::memory::dims src_dims = TransDims2Plain(input_shape, data_layout); - dnnl::memory::dims weights_dims_ = TransDims2Plain(weight_shape, kernel_layout); - // legalize shape IOHW with layout OIHW - if (weights_dims_[0] == src_dims[1] && weights_dims_[1] == channels) { - std::swap(weights_dims_[0], weights_dims_[1]); - if (kernel_layout.find("OI") == 0) { - kernel_layout.replace(kernel_layout.find("OI"), 2, "IO"); - } - } - weights_dims_[0] = channels; - weights_dims_[1] = src_dims[1]; - dnnl::memory::dims bias_dims = {channels}; - dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides); - dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, true); - dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l); - dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r); - dnnl::memory::dims out_padding = TransformStr2Dims(str_out_padding); - dnnl::memory::dims dst_dims = src_dims; - dst_dims[1] = channels; - for (size_t i = 2; i < src_dims.size(); i++) { - dnnl::memory::dim K = weights_dims_[i]; - dnnl::memory::dim S = strides_dims[i - 2]; - dnnl::memory::dim D = dilates_dims[i - 2]; - dnnl::memory::dim PL = padding_dims_l[i - 2]; - dnnl::memory::dim PR = padding_dims_r[i - 2]; - dnnl::memory::dim OP = out_padding[i - 2]; - dnnl::memory::dim DK = 1 + (K - 1) * (D + 1); - dst_dims[i] = S * (src_dims[i] - 1) + DK - PL - PR + OP; + auto src_tr = GetInput(nid, 0); + auto wgh_tr = GetInput(nid, 1); + auto dst_tr = GetOutput(nid, 0); + auto bias_tr = has_bias ? GetInput(nid, 2) : GetInput(nid, -1); + + auto strides = GetNodeAttr>(node, "strides"); + auto dilates = GetNodeAttr>(node, "dilation"); + auto padding = GetNodeAttr>(node, "padding"); + std::vector padding_l(padding.begin(), padding.begin() + padding.size() / 2); + std::vector padding_r(padding.begin() + padding.size() / 2, padding.end()); + auto groups = GetNodeAttr(node, "groups"); + auto src_layout = GetNodeAttr(node, "data_layout"); + auto dst_layout = GetNodeAttr(node, "out_layout"); + auto wgh_layout = GetNodeAttr(node, "kernel_layout"); + + // dst_layout == "" means to use data_layout + if (dst_layout.empty()) dst_layout = src_layout; + + // Minus one for DNNL representation. No dilation for DNNL is 0, for relay is 1. + for (auto& d : dilates) d--; + + // TODO(@apeskov): WA. conv3dTranspose uses wrong layout specifier. IO instead of OI. + auto wgh_logic_layout = TensorRequisite::DefaultLogicLayoutFor(wgh_layout); + if (wgh_logic_layout == "OIDHW") wgh_logic_layout = "IODHW"; + if (wgh_logic_layout == "GOIDHW") wgh_logic_layout = "GIODHW"; + + // Take into account provided layout strings + src_tr = src_tr.TreatAs(src_layout); + dst_tr = dst_tr.TreatAs(dst_layout); + wgh_tr = wgh_tr.TreatAs(wgh_layout, wgh_logic_layout); + + // Should support G mixed with O. Like { G*O, I, H, W } + if (wgh_layout.find("G") == std::string::npos) { + auto w_dims = wgh_tr.dims(); + w_dims[0] /= groups; + w_dims.insert(w_dims.begin(), groups); + wgh_tr = wgh_tr.Reshape(w_dims); } - dnnl::memory::dims weights_dims = weights_dims_; - if (groups > 1) { - weights_dims = {groups, channels / groups, src_dims[1] / groups}; - weights_dims.insert(weights_dims.end(), weights_dims_.begin() + 2, weights_dims_.end()); - } + // Assumption that bias is correct and can be squeezed to 1D + bias_tr = bias_tr.Reshape({dst_tr.dims()[1]}); - // Memory descriptions. - auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]); - auto deconv_src_md = dnnl::memory::desc(src_dims, dtype, layout2tag(data_layout)); - auto deconv_weights_md = dnnl::memory::desc(weights_dims, dtype, layout2tag(kernel_layout)); - auto deconv_bias_md = dnnl::memory::desc(bias_dims, dtype, tag::x); - auto deconv_dst_md = dnnl::memory::desc(dst_dims, dtype, tag::any); - - // Transposed covn2d description. - auto deconv_desc = - has_bias ? dnnl::deconvolution_forward::desc( - dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, - deconv_src_md, deconv_weights_md, deconv_bias_md, deconv_dst_md, - strides_dims, dilates_dims, padding_dims_l, padding_dims_r) - : dnnl::deconvolution_forward::desc( - dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, - deconv_src_md, deconv_weights_md, deconv_dst_md, strides_dims, dilates_dims, - padding_dims_l, padding_dims_r); + // Conv description. + auto deconv_desc = dnnl::deconvolution_forward::desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, + src_tr.LayoutAny().desc(), wgh_tr.LayoutAny().desc(), bias_tr.LayoutAny().desc(), + dst_tr.LayoutAny().desc(), strides, dilates, padding_l, padding_r); // Enable elementwise post-ops. auto deconv_prim_desc = dnnl::deconvolution_forward::primitive_desc(deconv_desc, attr, engine_); - // Push to the network. - auto deconv = dnnl::deconvolution_forward(deconv_prim_desc); - net_.push_back(deconv); - - // Data memory. - auto deconv_src_memory = BindDNNLMemory(data_entry, deconv_src_md); - - // Weight memory. - auto deconv_weights_memory = BindDNNLMemory(weight_entry, deconv_prim_desc.weights_desc()); - - // Output memory. - auto deconv_dst_memory = BindDNNLMemory(out_entry, deconv_prim_desc.dst_desc()); + src_tr = src_tr.RequestLayout(deconv_prim_desc.src_desc()); + wgh_tr = wgh_tr.RequestLayout(deconv_prim_desc.weights_desc()); + dst_tr = dst_tr.RequestLayout(deconv_prim_desc.dst_desc()); + bias_tr = bias_tr.RequestLayout(deconv_prim_desc.bias_desc()); - // Bias memory. - auto deconv_bias_memory = dnnl::memory({bias_dims, dtype, tag::x}, engine_); - if (has_bias) { - auto bias_entry = node.GetInputs()[2]; - BindDNNLMemory(bias_entry, deconv_bias_memory); + auto scratchpad_tr = TensorRequisite::AsIs(deconv_prim_desc.scratchpad_desc()); - // Bind memory buffers. - net_args_.push_back({{DNNL_ARG_SRC, deconv_src_memory}, - {DNNL_ARG_WEIGHTS, deconv_weights_memory}, - {DNNL_ARG_BIAS, deconv_bias_memory}, - {DNNL_ARG_DST, deconv_dst_memory}}); - } else { - // Bind memory buffers. - net_args_.push_back({{DNNL_ARG_SRC, deconv_src_memory}, - {DNNL_ARG_WEIGHTS, deconv_weights_memory}, - {DNNL_ARG_DST, deconv_dst_memory}}); - } + Submit(dnnl::deconvolution_forward(deconv_prim_desc), {{DNNL_ARG_SRC, src_tr}, + {DNNL_ARG_WEIGHTS, wgh_tr}, + {DNNL_ARG_BIAS, bias_tr}, + {DNNL_ARG_SCRATCHPAD, scratchpad_tr}, + {DNNL_ARG_DST, dst_tr}}); } void Dense(const size_t& nid) { auto node = nodes_[nid]; auto op_name = node.GetOpName(); dnnl::primitive_attr attr; + attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); bool has_bias = ParsingOpName(op_name, attr); // Setup attributes. - auto data_entry = node.GetInputs()[0]; - auto weight_entry = node.GetInputs()[1]; - JSONGraphNodeEntry out_entry(nid, 0); - dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_]; - dnnl::memory::dims out_shape = nodes_[out_entry.id_].GetOpShape()[out_entry.index_]; - dnnl::memory::dim OC = out_shape[1]; - - // Memory shapes. - dnnl::memory::dims data_dims = input_shape; - dnnl::memory::dims weight_dims = weight_shape; - dnnl::memory::dims bias_dims = {OC}; - dnnl::memory::dims out_dims = out_shape; - - // Memory descriptions. - auto dl_dtype = nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]; - auto dtype = dtype_dl2dnnl(dl_dtype); - auto data_md = dnnl::memory::desc({data_dims, dtype, tag::nc}); - auto weight_md = dnnl::memory::desc({weight_dims, dtype, tag::nc}); - auto bias_md = dnnl::memory::desc({bias_dims, dtype, tag::x}); - auto dst_md = dnnl::memory::desc({out_dims, dtype, tag::nc}); + auto src_tr = GetInput(nid, 0); + auto wgh_tr = GetInput(nid, 1); + auto dst_tr = GetOutput(nid, 0); + auto bias_tr = has_bias ? GetInput(nid, 2) : GetInput(nid, -1); + + // Assumption that bias is correct and can be squeezed to 1D + bias_tr = bias_tr.Reshape({dst_tr.dims()[1]}); // Dense description. - auto dense_desc = dnnl::inner_product_forward::desc(dnnl::prop_kind::forward_inference, data_md, - weight_md, bias_md, dst_md); + auto dense_desc = dnnl::inner_product_forward::desc( + dnnl::prop_kind::forward_inference, src_tr.LayoutAny().desc(), wgh_tr.LayoutAny().desc(), + bias_tr.LayoutAny().desc(), dst_tr.LayoutAny().desc()); // Enable elementwise post-ops. auto dense_prim_desc = dnnl::inner_product_forward::primitive_desc(dense_desc, attr, engine_); - auto dense = dnnl::inner_product_forward(dense_prim_desc); - net_.push_back(dense); + src_tr = src_tr.RequestLayout(dense_prim_desc.src_desc()); + wgh_tr = wgh_tr.RequestLayout(dense_prim_desc.weights_desc()); + dst_tr = dst_tr.RequestLayout(dense_prim_desc.dst_desc()); + bias_tr = bias_tr.RequestLayout(dense_prim_desc.bias_desc()); - // Memories. - auto data_memory = BindDNNLMemory(data_entry, data_md); - auto weight_memory = BindDNNLMemory(weight_entry, weight_md); + auto scratchpad_tr = TensorRequisite::AsIs(dense_prim_desc.scratchpad_desc()); - // Bias memory. - auto bias_memory = dnnl::memory(bias_md, engine_); - if (has_bias) { - auto bias_entry = node.GetInputs()[2]; - BindDNNLMemory(bias_entry, bias_memory); - } else { - float bias[OC] = {0}; - write_to_dnnl_memory(bias, bias_memory, OC * ((dl_dtype.bits + 7) / 8)); - } - - // Output memory. - auto dst_memory = BindDNNLMemory(out_entry, dense_prim_desc.dst_desc()); - - net_args_.push_back({{DNNL_ARG_SRC, data_memory}, - {DNNL_ARG_WEIGHTS, weight_memory}, - {DNNL_ARG_BIAS, bias_memory}, - {DNNL_ARG_DST, dst_memory}}); + Submit(dnnl::inner_product_forward(dense_prim_desc), {{DNNL_ARG_SRC, src_tr}, + {DNNL_ARG_WEIGHTS, wgh_tr}, + {DNNL_ARG_BIAS, bias_tr}, + {DNNL_ARG_SCRATCHPAD, scratchpad_tr}, + {DNNL_ARG_DST, dst_tr}}); } void BatchNorm(const size_t& nid) { auto node = nodes_[nid]; - auto data_entry = node.GetInputs()[0]; - auto gamma_entry = node.GetInputs()[1]; - auto beta_entry = node.GetInputs()[2]; - auto mean_entry = node.GetInputs()[3]; - auto variance_entry = node.GetInputs()[4]; - dnnl::memory::dims data_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - dnnl::memory::dim IC = data_shape[1]; - float epsilon = std::stof(node.GetAttr>("epsilon")[0]); + auto src_tr = GetInput(nid, 0); + auto gamma_tr = GetInput(nid, 1); + auto beta_tr = GetInput(nid, 2); + auto mean_tr = GetInput(nid, 3); + auto var_tr = GetInput(nid, 4); + auto dst_tr = GetOutput(nid, 0); - // Memory description. - auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]); - dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dtype); + auto axis = GetNodeAttr(node, "axis"); + auto epsilon = GetNodeAttr(node, "epsilon"); + auto center = GetNodeAttr(node, "center"); + auto scale = GetNodeAttr(node, "scale"); + + ICHECK(axis == 1 && center && scale) << "Unimplemented BatchNorm case"; - // BN description. auto bn_desc = dnnl::batch_normalization_forward::desc( - dnnl::prop_kind::forward_inference, data_md, epsilon, + dnnl::prop_kind::forward_inference, src_tr.desc(), epsilon, dnnl::normalization_flags::use_global_stats | dnnl::normalization_flags::use_scale_shift); auto bn_prim_desc = dnnl::batch_normalization_forward::primitive_desc(bn_desc, engine_); - auto bn = dnnl::batch_normalization_forward(bn_prim_desc); - net_.push_back(bn); - - // Memories. - auto data_memory = BindDNNLMemory(data_entry, data_md); - JSONGraphNodeEntry out_entry(nid, 0); - auto out_memory = BindDNNLMemory(out_entry, data_md); - auto mean_memory = BindDNNLMemory(mean_entry, bn_prim_desc.mean_desc()); - auto variance_memory = BindDNNLMemory(variance_entry, bn_prim_desc.variance_desc()); - - // In DNNL, weight is composed of gamma+beta, so we point them to the same DNNL memory but - // assign an offset to beta data for runtime serialization. - auto weight_memory = BindDNNLMemory(gamma_entry, bn_prim_desc.weights_desc(), 0); - BindDNNLMemory(beta_entry, weight_memory, IC); - - net_args_.push_back({{DNNL_ARG_SRC, data_memory}, - {DNNL_ARG_DST, out_memory}, - {DNNL_ARG_SCALE_SHIFT, weight_memory}, - {DNNL_ARG_MEAN, mean_memory}, - {DNNL_ARG_VARIANCE, variance_memory}}); + + // Concatenate scale and shift tensors + auto scale_shift_tr = TensorRequisite::AsIs(bn_prim_desc.weights_desc(), GenUniqueEid()); + auto sc_sh_dims = scale_shift_tr.dims(); + ICHECK(sc_sh_dims.size() == 2); + ICHECK(sc_sh_dims[0] == 2); + sc_sh_dims[0] /= 2; + auto scale_tr = scale_shift_tr.Crop(sc_sh_dims, {0, 0}).Squeeze(); + auto shift_tr = scale_shift_tr.Crop(sc_sh_dims, {1, 0}).Squeeze(); + + auto register_copy = [this](const TensorRequisite& src, const TensorRequisite& dst) { + dnnl::reorder::primitive_desc copy_pd(engine_, src.desc(), engine_, dst.desc()); + Submit(dnnl::reorder(copy_pd), {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}}); + }; + + register_copy(gamma_tr, scale_tr); + register_copy(beta_tr, shift_tr); + + Submit(dnnl::batch_normalization_forward(bn_prim_desc), {{DNNL_ARG_SRC, src_tr}, + {DNNL_ARG_DST, dst_tr}, + {DNNL_ARG_SCALE_SHIFT, scale_shift_tr}, + {DNNL_ARG_MEAN, mean_tr}, + {DNNL_ARG_VARIANCE, var_tr}}); } void Pooling(const size_t& nid, dnnl::algorithm algo) { auto node = nodes_[nid]; + auto src_tr = GetInput(nid, 0); + auto dst_tr = GetOutput(nid, 0); + // Setup attributes. - auto data_entry = node.GetInputs()[0]; - JSONGraphNodeEntry out_entry(nid, 0); - dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - dnnl::memory::dims out_shape = nodes_[out_entry.id_].GetOpShape()[out_entry.index_]; - std::vector str_kernel = node.GetAttr>("pool_size"); - std::vector str_strides = node.GetAttr>("strides"); - std::vector str_padding = node.GetAttr>("padding"); - std::vector str_padding_l(str_padding.begin(), - str_padding.begin() + str_padding.size() / 2); - std::vector str_padding_r(str_padding.end() - str_padding.size() / 2, - str_padding.end()); - std::vector str_dilates = node.GetAttr>("dilation"); - std::string layout = node.GetAttr>("layout")[0]; + auto strides = GetNodeAttr>(node, "strides"); + auto dilates = GetNodeAttr>(node, "dilation"); + auto padding = GetNodeAttr>(node, "padding"); + std::vector padding_l(padding.begin(), padding.begin() + padding.size() / 2); + std::vector padding_r(padding.begin() + padding.size() / 2, padding.end()); + auto kernel = GetNodeAttr>(node, "pool_size"); + auto src_layout = GetNodeAttr(node, "layout"); + auto dst_layout = GetNodeAttr(node, "out_layout"); + + // dst_layout == "" means to use data_layout + if (dst_layout.empty()) dst_layout = src_layout; + + // Minus one for DNNL representation. No dilation for DNNL is 0, for relay is 1. + for (auto& d : dilates) d--; + + // Take into account provided layout strings + src_tr = src_tr.TreatAs(src_layout); + dst_tr = dst_tr.TreatAs(dst_layout); // Attributes related to AvgPool if (algo == dnnl::algorithm::pooling_avg) { - int int_countpad = std::stoi(node.GetAttr>("count_include_pad")[0]); - bool count_include_pad = int_countpad != 0 ? true : false; - algo = count_include_pad ? dnnl::algorithm::pooling_avg_include_padding - : dnnl::algorithm::pooling_avg_exclude_padding; + auto include_pad = GetNodeAttr(node, "count_include_pad"); + algo = include_pad ? dnnl::algorithm::pooling_avg_include_padding + : dnnl::algorithm::pooling_avg_exclude_padding; } - dnnl::memory::dims src_dims = TransDims2Plain(input_shape, layout); - dnnl::memory::dims dst_dims = TransDims2Plain(out_shape, layout); - dnnl::memory::dims kernel_dims = TransformStr2Dims(str_kernel); - dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides); - dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, true); - dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l); - dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r); - - // Memory descriptions. - auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]); - auto pool_src_md = dnnl::memory::desc(src_dims, dtype, layout2tag(layout)); - auto pool_dst_md = dnnl::memory::desc(dst_dims, dtype, tag::any); - // Pooling description. - auto pool_desc = dnnl::pooling_forward::desc(dnnl::prop_kind::forward_inference, algo, - pool_src_md, pool_dst_md, strides_dims, - kernel_dims, padding_dims_l, padding_dims_r); - - auto pool_prim_desc = dnnl::pooling_forward::primitive_desc(pool_desc, engine_, true); - auto pool = dnnl::pooling_forward(pool_prim_desc); - net_.push_back(pool); + auto pool_desc = dnnl::pooling_v2_forward::desc( + dnnl::prop_kind::forward_inference, algo, src_tr.desc(), //<= Do not use any for src tensor + dst_tr.LayoutAny().desc(), strides, kernel, dilates, padding_l, padding_r); + auto pool_prim_desc = dnnl::pooling_v2_forward::primitive_desc(pool_desc, engine_); - // Memories. - auto pool2d_src_memory = BindDNNLMemory(data_entry, pool_src_md); + src_tr = src_tr.RequestLayout(pool_prim_desc.src_desc()); + dst_tr = dst_tr.RequestLayout(pool_prim_desc.dst_desc()); - auto pool2d_dst_memory = BindDNNLMemory(out_entry, pool_prim_desc.dst_desc()); + auto scratchpad_tr = TensorRequisite::AsIs(pool_prim_desc.scratchpad_desc()); - // Bind memory buffers. - net_args_.push_back({{DNNL_ARG_SRC, pool2d_src_memory}, {DNNL_ARG_DST, pool2d_dst_memory}}); + Submit(dnnl::pooling_v2_forward(pool_prim_desc), + {{DNNL_ARG_SRC, src_tr}, {DNNL_ARG_DST, dst_tr}, {DNNL_ARG_SCRATCHPAD, scratchpad_tr}}); } void Eltwise(const size_t& nid) { auto node = nodes_[nid]; auto op_name = node.GetOpName(); - auto algo = elt_name2algo[op_name]; + auto algo = elt_name2algo.at(op_name); + + auto src_tr = GetInput(nid, 0); + auto dst_tr = GetOutput(nid, 0); - auto data_entry = node.GetInputs()[0]; - dnnl::memory::dims shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]); - dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dtype); float alpha = 0., beta = 0.; if (op_name == "clip") { - alpha = std::stof(node.GetAttr>("a_min")[0]); - beta = std::stof(node.GetAttr>("a_max")[0]); + alpha = GetNodeAttr(node, "a_min"); + beta = GetNodeAttr(node, "a_max"); } else if (op_name == "nn.leaky_relu") { - alpha = std::stof(node.GetAttr>("alpha")[0]); + alpha = GetNodeAttr(node, "alpha"); } - auto elt_desc = - dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, data_md, alpha, beta); + auto elt_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, + src_tr.desc(), alpha, beta); auto elt_prim_desc = dnnl::eltwise_forward::primitive_desc(elt_desc, engine_); - ICHECK(data_md == elt_prim_desc.dst_desc()); - - auto elt = dnnl::eltwise_forward(elt_prim_desc); - net_.push_back(elt); + ICHECK(src_tr.desc() == elt_prim_desc.dst_desc()); - auto data_memory = BindDNNLMemory(data_entry, data_md); - JSONGraphNodeEntry out_entry(nid, 0); - auto out_memory = BindDNNLMemory(out_entry, data_md); - - net_args_.push_back({{DNNL_ARG_SRC, data_memory}, {DNNL_ARG_DST, out_memory}}); + Submit(dnnl::eltwise_forward(elt_prim_desc), {{DNNL_ARG_SRC, src_tr}, {DNNL_ARG_DST, dst_tr}}); } void Softmax(const size_t& nid) { auto node = nodes_[nid]; - auto data_entry = node.GetInputs()[0]; - dnnl::memory::dims shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - int axis = std::stoi(node.GetAttr>("axis")[0]); + auto src_tr = GetInput(nid, 0); + auto dst_tr = GetOutput(nid, 0); + + auto axis = GetNodeAttr(node, "axis"); if (axis < 0) { - axis = shape.size() + axis; + axis = src_tr.dims().size() + axis; } - auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]); - dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dtype); auto softmax_desc = - dnnl::softmax_forward::desc(dnnl::prop_kind::forward_inference, data_md, axis); + dnnl::softmax_forward::desc(dnnl::prop_kind::forward_inference, src_tr.desc(), axis); auto softmax_prim_desc = dnnl::softmax_forward::primitive_desc(softmax_desc, engine_); - ICHECK(data_md == softmax_prim_desc.dst_desc()); - - auto softmax = dnnl::softmax_forward(softmax_prim_desc); - net_.push_back(softmax); + ICHECK(dst_tr.desc() == softmax_prim_desc.dst_desc()); - auto data_memory = BindDNNLMemory(data_entry, data_md); - JSONGraphNodeEntry out_entry(nid, 0); - auto out_memory = BindDNNLMemory(out_entry, data_md); - - net_args_.push_back({{DNNL_ARG_SRC, data_memory}, {DNNL_ARG_DST, out_memory}}); + Submit(dnnl::softmax_forward(softmax_prim_desc), + {{DNNL_ARG_SRC, src_tr}, {DNNL_ARG_DST, dst_tr}}); } void Binary(const size_t& nid, dnnl::algorithm algo) { auto node = nodes_[nid]; + ICHECK_EQ(node.GetInputs().size(), 2U); // Memory and compute description. - std::vector data_dims; - std::vector data_mds; - std::vector data_memories; + auto lhs_tr = GetInput(nid, 0); + auto rhs_tr = GetInput(nid, 1); + auto dst_tr = GetOutput(nid, 0); - ICHECK_EQ(node.GetInputs().size(), 2U); - for (auto entry : node.GetInputs()) { - auto data_shape = nodes_[entry.id_].GetOpShape()[entry.index_]; - auto dtype = dtype_dl2dnnl(nodes_[entry.id_].GetOpDataType()[entry.index_]); - dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dtype); - - data_dims.push_back(data_shape); - data_mds.push_back(data_md); - data_memories.push_back(BindDNNLMemory(entry, data_md)); - } - ICHECK(data_dims[0] == data_dims[1]); - auto out_md = data_mds[0]; - JSONGraphNodeEntry out_entry(nid, 0); - auto out_memory = BindDNNLMemory(out_entry, out_md); + lhs_tr = lhs_tr.Broadcast(dst_tr.dims()); + rhs_tr = rhs_tr.Broadcast(dst_tr.dims()); - auto binary_desc = dnnl::binary::desc(algo, data_mds[0], data_mds[1], out_md); + auto binary_desc = dnnl::binary::desc(algo, lhs_tr.desc(), rhs_tr.desc(), dst_tr.desc()); auto binary_prim_desc = dnnl::binary::primitive_desc(binary_desc, engine_); - auto binary = dnnl::binary(binary_prim_desc); - net_.push_back(binary); - net_args_.push_back({{DNNL_ARG_SRC_0, data_memories[0]}, - {DNNL_ARG_SRC_1, data_memories[1]}, - {DNNL_ARG_DST, out_memory}}); + Submit(dnnl::binary(binary_prim_desc), + {{DNNL_ARG_SRC_0, lhs_tr}, {DNNL_ARG_SRC_1, rhs_tr}, {DNNL_ARG_DST, dst_tr}}); + } + + template ::value, int> = 0> + T AttrConvert(std::vector val) { + ICHECK_EQ(val.size(), 1); + return std::stol(val[0]); + } + + template ::value, int> = 0> + T AttrConvert(std::vector val) { + ICHECK_EQ(val.size(), 1); + return std::stof(val[0]); + } + + template ::value, int> = 0> + T AttrConvert(std::vector val) { + ICHECK_EQ(val.size(), 1); + return val[0]; + } + + template >::value, int> = 0> + T AttrConvert(std::vector val) { + T res; + for (const auto& el : val) res.push_back(AttrConvert({el})); + return res; + } + + /*! + * \brief Helper to extract node attribute with ability to specify default value and result type. + */ + template + const T GetNodeAttr(const json::JSONGraphNode& node, std::string name, + std::vector def = {}) { + auto attr = node.HasAttr(name) ? node.GetAttr>(name) : def; + return AttrConvert(attr); } - // Read from DNNL memory (+offset) and write to the handle. - inline void read_from_dnnl_memory(void* handle, const dnnl::memory& mem, size_t size, - size_t offset = 0) { - uint8_t* src = static_cast(mem.get_data_handle()); - std::copy(src + offset, src + offset + size, static_cast(handle)); + TensorRequisite GetInput(const size_t& nid, const int idx) { + if (idx == -1) return {}; // -1 reserved value for empty input. + + const JSONGraphNode& node = nodes_[nid]; + + ICHECK_LT(idx, node.GetInputs().size()); + auto data_entry = node.GetInputs()[idx]; + + auto shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; + auto dtype = nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]; + auto eid = node_row_ptr_[data_entry.id_] + data_entry.index_; + auto const_dl_tensor = data_entry_[eid]; + + auto desc = MakePlainDesc(shape, dtype); + + TensorRequisite res; + if (const_dl_tensor) { + ICHECK(const_dl_tensor->data); + ICHECK(const_dl_tensor->strides == nullptr); + auto mem = dnnl::memory(desc, engine_, const_dl_tensor->data); + res = TensorRequisite::AsIs(mem, eid); + } else { + res = TensorRequisite::AsIs(desc, eid); + } + return res; } - // Read from the handle and write to DNNL memory (+offset). - inline void write_to_dnnl_memory(void* handle, const dnnl::memory& mem, size_t size, - size_t offset = 0) { - uint8_t* dst = static_cast(mem.get_data_handle()); - std::copy(reinterpret_cast(handle), reinterpret_cast(handle) + size, - dst + offset); + TensorRequisite GetOutput(const size_t& nid, const int idx) { + if (idx == -1) return {}; // -1 reserved value for empty input. + + const JSONGraphNode& node = nodes_[nid]; + + ICHECK_LT(idx, node.GetNumOutput()); + auto shape = node.GetOpShape()[idx]; + auto dtype = node.GetOpDataType()[idx]; + auto eid = node_row_ptr_[nid] + static_cast(idx); + + ICHECK(data_entry_[eid] == nullptr); + auto desc = MakePlainDesc(shape, dtype); + + return TensorRequisite::AsIs(desc, eid).Backward(); } - // Generate DNNL memory description and infer the data layout by the given shape. - inline dnnl::memory::desc GenDNNLMemDescByShape(const dnnl::memory::dims& shape, dt dtype) { - dnnl::memory::desc data_md; - switch (shape.size()) { - case 2: - data_md = dnnl::memory::desc({shape, dtype, tag::ab}); - break; - case 3: - data_md = dnnl::memory::desc({shape, dtype, tag::abc}); - break; - case 4: - data_md = dnnl::memory::desc({shape, dtype, tag::abcd}); - break; - case 5: - data_md = dnnl::memory::desc({shape, dtype, tag::abcde}); - break; - default: - LOG(FATAL) << "Unsupported data shape dimension: " << shape.size(); - break; + /*! \brief Helper function to register primitive into execution queue */ + void Submit(const dnnl::primitive& prim, + const std::unordered_map& tr_args) { + // Register all provided TR arguments + std::unordered_map prim_arg_id; + TensorRegistry::ActionQue post_prim_actions; + for (const auto& kvp : tr_args) { + const auto& key = kvp.first; + const auto& tr = kvp.second; + + if (!tr.defined()) continue; // empty arg is admitted. Just skip it + auto arg_id = tensor_registry_.Register(tr, tr.IsReversed() ? &post_prim_actions : &net_); + prim_arg_id[key] = arg_id; } - return data_md; + + // Register main primitive + net_.push_back({prim, prim_arg_id}); + + // Register post actions + net_.insert(net_.end(), post_prim_actions.begin(), post_prim_actions.end()); } + uint32_t GenUniqueEid() { return next_unique_eid_offset_++; } + /* The dnnl engine. */ dnnl::engine engine_; /* The dnnl stream. */ dnnl::stream stream_; /* The network layers that are represented in dnnl primitives. */ - std::vector net_; - /* The memory that is consumed by arguments. */ - std::vector> net_args_; - /* The entry ID to its corresponding output memory. */ - std::unordered_map> entry_out_mem_; + TensorRegistry::ActionQue net_; + /* Storage for all memory objects */ + TensorRegistry tensor_registry_; + /* Generator of new unique eid which doesn't match with existing data entry */ + uint32_t next_unique_eid_offset_; + /* Map of Run arg idx to corresponding eid */ + std::vector run_arg_eid_; }; runtime::Module DNNLJSONRuntimeCreate(String symbol_name, String graph_json, diff --git a/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h b/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h new file mode 100644 index 000000000000..d02ceff5de82 --- /dev/null +++ b/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h @@ -0,0 +1,720 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/dnnl/dnnl_tensor_requisite.cc + * \brief Helper TR wrapper to simplify tensors processing + */ + +#ifndef TVM_RUNTIME_CONTRIB_DNNL_DNNL_TENSOR_REQUISITE_H_ +#define TVM_RUNTIME_CONTRIB_DNNL_DNNL_TENSOR_REQUISITE_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// TODO(@apeskov): Have to mute warning from dnnl headers. +// -Wzero-as-null-pointer-constant and -Wdocumentation-unknown-command +#include + +#include "dnnl_utils.h" + +namespace tvm { +namespace runtime { +namespace contrib { + +using namespace utils; + +/*! + * \brief Helper object to simplify tensor transformation description. + * + * Allow to specify original source tensor and future actions which should be applied to it. + * Can be treated as sequence of reordering or reinterpretation of original source tensor. + * Finally TR can be solved as proper interpretation of source memory buffer, or sequence of + * dnnl::reorder operators which will provide desired data. + * + * \note Empty TR object allow any manipulation. Empty TR will be returned. + * + * \sa TensorRegistry + * + * Example: + * \code + * dnnl::memory src_mem = ...; // 5D tensor, shape {5, 2, 128, 128, 8} + * + * // Construct TR + * auto tr = TensorRequisite.AsIs(src_mem, eid); // 5D + * + * // describe sequence of layout transformation + * tr = tr.TreatAs("ABCD8b"); // 4D + * tr = tr.Permute({0, 2, 3, 1}); // Permute axes NCHW -> NHWC + * tr = tr.Crop({1, 128, 128, 16}, {0, 0, 0}); // extract first batch element + * tr = tr.Squeeze(); // 1D + * + * // register TR + * TensorRegistry t_reg; + * auto t_id = t_reg.register(tr); + * + * // Get final dnnl::memory object + * auto solver = t_reg.MakeSolver(ext_tensor_provider); + * auto mem = solver(t_id); + * \endcode + * + */ +class TensorRequisite { + public: + using Tid = uint32_t; + static constexpr Tid kUndefinedTid = std::numeric_limits::max() - 1; + + /*! \brief Empty constructor */ + TensorRequisite() {} + + /*! \brief Construct TR on top of existing memory object */ + static TensorRequisite AsIs(const dnnl::memory& mem, Tid id = kUndefinedTid) { + auto res = AsIs(mem.get_desc(), id); + if (mem.get_data_handle() != nullptr) res.mem_ = mem; + return res; + } + + /*! \brief Construct TR on top of existing memory descriptor object */ + static TensorRequisite AsIs(const dnnl::memory::desc& desc, Tid id = kUndefinedTid) { + return {desc, {}, false, {}, id, false}; + } + + /*! \brief return logical shape of tensor */ + dnnl::memory::dims dims() const { return t_desc_.dims(); } + + /*! \brief return data type of tensor */ + dnnl::memory::data_type data_type() const { return t_desc_.data_type(); } + + /*! \brief return tensor desc */ + dnnl::memory::desc desc() const { return t_desc_; } + + /*! \brief Make TR with backward dataflow */ + TensorRequisite Backward() const { + if (!defined()) return *this; + ICHECK(orig_ == nullptr); + return {t_desc_, orig_, reinterpret_, mem_, eid_, true}; + } + + /*! \brief Produce TR with permuted axes */ + TensorRequisite Permute(const std::vector& permutation) const { + if (!defined()) return *this; // nothing for empty TR + + auto orig = std::make_shared(*this); + // reinterpret memory buffer with new strides + auto desc = t_desc_.permute_axes(permutation); + return {desc, orig, true, {}, kUndefinedTid, reverse_data_flow_}; + } + + /*! \brief Produce TR with reinterpret data of original tr */ + TensorRequisite Reshape(const dnnl::memory::dims& shape) const { + if (!defined()) return *this; // nothing for empty TR + if (t_desc_.dims() == shape) return *this; + + auto orig = std::make_shared(*this); + // reinterpret memory buffer with new strides + auto desc = t_desc_.reshape(shape); + return {desc, orig, true, {}, kUndefinedTid, reverse_data_flow_}; + } + + /*! \brief Produce TR with broadcasted values */ + TensorRequisite Broadcast(const dnnl::memory::dims& shape) const { + if (!defined()) return *this; // nothing for empty TR + if (t_desc_.dims() == shape) return *this; + ICHECK(!reverse_data_flow_); + + auto orig = std::make_shared(*this); + + // numpy like broadcast + auto extended_dims = t_desc_.dims(); + auto one_filled = dnnl::memory::dims(shape.size() - extended_dims.size(), 1); + extended_dims.insert(extended_dims.begin(), one_filled.begin(), one_filled.end()); + auto desc = t_desc_.reshape(extended_dims); + for (size_t i = 0; i < extended_dims.size(); i++) { + if (extended_dims[i] == shape[i]) continue; + ICHECK(extended_dims[i] == 1); + ICHECK(desc.data.dims[i] == desc.data.padded_dims[i]); + + desc.data.dims[i] = shape[i]; + desc.data.padded_dims[i] = shape[i]; + desc.data.format_desc.blocking.strides[i] = 0; + } + + // reinterpret memory buffer with new strides + return {desc, orig, true, {}, kUndefinedTid, reverse_data_flow_}; + } + + /*! \brief Produce TR with sub memory view (ROI) */ + TensorRequisite Crop(const dnnl::memory::dims& shape, const dnnl::memory::dims& offset) const { + if (!defined()) return *this; // nothing for empty TR + + ICHECK_EQ(shape.size(), t_desc_.dims().size()); + ICHECK_EQ(offset.size(), t_desc_.dims().size()); + + auto orig = std::make_shared(*this); + // reinterpret memory buffer with new strides + auto desc = t_desc_.submemory_desc(shape, offset, /*allow_empty=*/true); + + // Originally DNNL implementation is very limited. Let's slightly enhance it. + if (!desc && t_desc_.data.format_kind == dnnl_blocked) { + bool offset_is_zero = + std::all_of(offset.begin(), offset.end(), [](auto el) { return el == 0; }); + + dnnl::memory::dims block_sizes(t_desc_.dims().size(), 1); + for (int i = 0; i < t_desc_.data.format_desc.blocking.inner_nblks; i++) + block_sizes[t_desc_.data.format_desc.blocking.inner_idxs[i]] *= + t_desc_.data.format_desc.blocking.inner_blks[i]; + + bool shape_reduction_less_than_block = true; + for (int i = 0; i < t_desc_.data.ndims; i++) { + shape_reduction_less_than_block &= t_desc_.data.dims[i] - shape[i] < block_sizes[i]; + } + + // This is auto padded case. Just update dims value. + if (offset_is_zero && shape_reduction_less_than_block) { + desc = t_desc_; + std::copy(shape.begin(), shape.end(), desc.data.dims); + } + } + + ICHECK(desc); + + return {desc, orig, true, {}, kUndefinedTid, reverse_data_flow_}; + } + + /*! \brief Produce TR with squeeze shape */ + TensorRequisite Squeeze(const dnnl::memory::dims& dims_to_squeeze = {}) const { + if (!defined()) return *this; // nothing for empty TR + + dnnl::memory::dims squeezed_dims; + if (dims_to_squeeze.empty()) { + for (auto d : t_desc_.dims()) + if (d != 1) squeezed_dims.push_back(d); + } else { + for (size_t i = 0; i < t_desc_.dims().size(); i++) + if (std::find(dims_to_squeeze.begin(), dims_to_squeeze.end(), i) == dims_to_squeeze.end()) + squeezed_dims.push_back(t_desc_.dims()[i]); + } + + if (squeezed_dims.empty()) squeezed_dims = {1}; + + auto orig = std::make_shared(*this); + // reinterpret memory buffer with new strides + auto desc = t_desc_.reshape(squeezed_dims); + return {desc, orig, true, {}, kUndefinedTid, reverse_data_flow_}; + } + + /*! \brief Produce TR with specified layout descriptor */ + TensorRequisite RequestLayout(dnnl::memory::desc desc) const { + if (!defined()) return *this; // nothing for empty TR + + // If it's the same desc just return self + if (desc == t_desc_) return *this; + + ICHECK(t_desc_.dims() == desc.dims()) << "Requested layout is not compatible with " + "presented shape"; + + auto orig = std::make_shared(*this); + return {desc, orig, false, {}, kUndefinedTid, reverse_data_flow_}; + } + + /*! \brief Define which logical dims ordering is default for particular layout string. */ + static std::string DefaultLogicLayoutFor(const std::string& layout) { + // Rank is all non digit marked dims + auto it = layout.begin(); + while (it != layout.end() && !std::isdigit(*it)) it++; + int rank = std::distance(layout.begin(), it); + + static const std::vector sparse_dims = {"W", "HW", "DHW"}; + if (layout.find("N") != std::string::npos) return "NC" + sparse_dims[rank - 3]; + if (layout.find("G") != std::string::npos) return "GOI" + sparse_dims[rank - 4]; + if (layout.find("O") != std::string::npos) return "OI" + sparse_dims[rank - 3]; + + LOG(FATAL) << "Unknown layout " << layout << "There is no default scheme to handle it"; + return {}; + } + + /*! + * \brief Treat TR shape as described in layout string. + * + * Blocked dimensions will be concatenated and put into proper shape position corresponding to . + * resulting_layout_logic argument. If desired logic layout was not provided it will be deduced + * automatically based on some internal heuristics. + * + * Limitation 1. Blocking dims should be dense. Dims marked with digits use natural strides. + * Limitation 2. Blocking dims are innermost. Dims marked like 8c, 4o goes after regular + * dimensions. NC8cHW4h4cD is not valid tensor in terms of DNNL. And cannot be + * achieved with memory reinterpretation, so data copy is required. Proper layout + * looks like NCHWD_8c4h4c, first part is outer dims, second digits marked part is + * innermost. + */ + TensorRequisite TreatAs(const std::string& layout, std::string desired_logic_layout = "") const { + if (desired_logic_layout.empty()) desired_logic_layout = DefaultLogicLayoutFor(layout); + + const auto origin_dims = dims(); + + // split layout string to tokens {size, tag} like {16, 'C'}, {4, 'O'} + std::vector> layout_tokens; + for (auto it = layout.begin(); it != layout.end();) { + auto start = it; + while (std::isdigit(*it)) it++; + int blk_size = start == it ? -1 : std::stoi(std::string{start, it}); + layout_tokens.push_back({blk_size, std::toupper(*it)}); + it++; + } + + // check applicability of layout + auto it = layout_tokens.begin(); + while (it != layout_tokens.end() && it->first == -1) it++; + int rank = std::distance(layout_tokens.begin(), it); + while (it != layout_tokens.end()) { + ICHECK_NE(it->first, -1) << "DNNL limitation. Blocking dims should be innermost. " + << "But received layout is " << layout; + it++; + } + + ICHECK_EQ(layout_tokens.size(), origin_dims.size()); + ICHECK_EQ(rank, desired_logic_layout.size()) << layout; + + std::vector> outermost_tokens(layout_tokens.begin(), + layout_tokens.begin() + rank); + std::vector> innermost_tokens(layout_tokens.begin() + rank, + layout_tokens.end()); + // define dim resulting dim positions + std::map dim_position_by_tag; + for (size_t i = 0; i < desired_logic_layout.size(); i++) + dim_position_by_tag[std::toupper(desired_logic_layout[i])] = i; + + // Construct resulting desc by modifying original one + dnnl::memory::desc res_desc = t_desc_; + + memset(&res_desc.data.format_desc.blocking, 0, sizeof(res_desc.data.format_desc.blocking)); + std::fill(res_desc.data.dims, res_desc.data.dims + DNNL_MAX_NDIMS, 0); + std::fill(res_desc.data.padded_dims, res_desc.data.padded_dims + DNNL_MAX_NDIMS, 0); + + res_desc.data.ndims = rank; + res_desc.data.format_desc.blocking.inner_nblks = innermost_tokens.size(); + + auto res_dims = res_desc.data.dims; + auto res_strides = res_desc.data.format_desc.blocking.strides; + auto res_inner_blks = res_desc.data.format_desc.blocking.inner_blks; + auto res_inner_idxs = res_desc.data.format_desc.blocking.inner_idxs; + + std::fill(res_dims, res_dims + rank, 1); + + int orig_dim_idx = 0; + for (const auto& p : outermost_tokens) { + auto tag = p.second; + auto dim_size = origin_dims[orig_dim_idx]; + + auto result_dim_position = dim_position_by_tag[tag]; + res_dims[result_dim_position] *= dim_size; + res_strides[result_dim_position] = t_desc_.data.format_desc.blocking.strides[orig_dim_idx]; + orig_dim_idx++; + } + for (const auto& p : innermost_tokens) { + auto tag = p.second; + auto dim_size = origin_dims[orig_dim_idx]; + auto result_dim_position = dim_position_by_tag[tag]; + ICHECK_EQ(p.first, dim_size) + << "Blocking layout is not applicable to tensor with shape: " << origin_dims + << ". Requested layout is " << layout; + + res_dims[result_dim_position] *= dim_size; + *res_inner_blks++ = dim_size; + *res_inner_idxs++ = result_dim_position; + orig_dim_idx++; + } + + // Assume tensor is dense. There is no additional padding. + std::copy(res_desc.data.dims, res_desc.data.dims + rank, res_desc.data.padded_dims); + + if (t_desc_ == res_desc) return *this; + + auto orig = std::make_shared(*this); + return {res_desc, orig, true, {}, kUndefinedTid, reverse_data_flow_}; + } + + /*! + * \brief Produce TR with unspecified layout. + * + * Cannot be registered in TensorRegistry. Only for querying DNNL for preferred layouts. + */ + TensorRequisite LayoutAny() const { + auto orig = std::make_shared(*this); + // Recreate tensor desc with layout 'any' + dnnl::memory::desc any_desc{t_desc_.dims(), t_desc_.data_type(), dnnl::memory::format_tag::any}; + return {any_desc, orig, false, {}, kUndefinedTid, reverse_data_flow_}; + } + + /*! \brief Check is TR is constant. */ + bool IsConstant() const { + if (orig_) return orig_->IsConstant(); + return mem_.operator bool(); + } + + /*! \brief Check is tensor is scalar. */ + bool IsScalar() const { return t_desc_.dims().size() == 1 && t_desc_.dims()[0] == 1; } + + /*! \brief Return const data memory if available. */ + dnnl::memory GetConstData() const { + if (mem_) return mem_; + if (!orig_) return {}; + + if (auto orig_const_data = orig_->GetConstData()) { + if (reinterpret_) { + return {t_desc_, orig_const_data.get_engine(), orig_const_data.get_data_handle()}; + } else { + auto eng = orig_const_data.get_engine(); + auto res = dnnl::memory{t_desc_, eng}; + dnnl::reorder(orig_const_data, res).execute(dnnl::stream(eng), orig_const_data, res); + return res; + } + } + return {}; + } + + /*! + * \brief Return const data memory in form of vector. + * + * Same as GetConstData but use std::vector instead of dnnl::memory. Works only for 1D tensor + * and scalar TRs. Useful for specification of 1D DNNL attributes like zero_point or + * per_channel_scale + */ + template + std::vector GetConstDataLikeVec() const { + auto const_data = GetConstData(); + auto desc = const_data.get_desc(); + ICHECK(desc.data_type() == utils::DnnlDType()); + ICHECK(desc.dims().size() == 1); + + auto size = desc.get_size() / sizeof(T); + auto ptr = static_cast(const_data.get_data_handle()); + + return std::vector(ptr, ptr + size); + } + + /*! \brief Get value of constant scalar tensor if possible. */ + template + T GetConstScalarData() const { + ICHECK(IsConstant()); + ICHECK(IsScalar()); + auto const_data = GetConstData(); + auto desc = const_data.get_desc(); + ICHECK(desc.data_type() == utils::DnnlDType()); + + auto ptr = static_cast(const_data.get_data_handle()); + return *ptr; + } + + /*! \brief Check if tensor is not empty. */ + bool defined() const { return !t_desc_.is_zero(); } + + /*! \brief Same as defined */ + operator bool() const { return defined(); } + + /*! + * \brief Check if tensor represent a reversed data flow. + * Useful for describing output processing + */ + bool IsReversed() const { return reverse_data_flow_; } + + private: + TensorRequisite(const dnnl::memory::desc& t_desc, const std::shared_ptr& orig, + bool reinterpret, const dnnl::memory& const_mem, uint32_t eid, + bool reverse_data_flow) + : t_desc_(t_desc), + orig_(orig), + reinterpret_(reinterpret), + mem_(const_mem), + eid_(eid), + reverse_data_flow_(reverse_data_flow) { + if (mem_) ICHECK(!orig_ && !reverse_data_flow_ && eid_ == kUndefinedTid); + if (eid_ != kUndefinedTid) ICHECK(!orig_); + } + + /* Descriptor of particular tensor */ + dnnl::memory::desc t_desc_ = {}; + /* Parent TR object which is referred from this TR */ + std::shared_ptr orig_ = {}; + /* Flag to specify which action should be done with orig TR, reordering or reinterpretation */ + bool reinterpret_ = false; + /* Const memory object if available */ + dnnl::memory mem_ = {}; + /* Entry ID of tensor if available */ + uint32_t eid_ = kUndefinedTid; + + /* + * Flag to describe reverse data flow case + * All operation on queue will be executed in reverse order. Actual for dst tensor description + */ + bool reverse_data_flow_ = false; + + friend class TensorRegistry; +}; + +/*! + * \brief The registry of tensors. Implement matching of provided TRs and real memory buffers. + * + * Registration of TR performed by calling method Register(), which will return ArgId object. + * ArgId can be mapped to real memory via memory solver created by method MakeSolver(). + */ +class TensorRegistry { + private: + enum ArgReqFlag { + CONST, /// < Constant tensor. ExecutionCTX independent + TMP_STORAGE, /// < Intermediate tensors. Stored inside TensorRegistry. Inaccessible outside + EXT_EID, /// < External data. Input or Output. + }; + + public: + struct ArgId { + TensorRegistry::ArgReqFlag flag_; + uint32_t idx_; + }; + + using Action = std::tuple>; + using ActionQue = std::vector; + using DLTensorProvider = std::function; + using MemSolver = std::function; + + TensorRegistry() = default; + TensorRegistry(const dnnl::engine& eng, const std::set& ext_io_eid) + : tmp_mem_collection_(1), ext_io_eid_(ext_io_eid), eng_(eng), stream_(eng) {} + + /*! + * \brief Register TR to registry + * + * Resolution of TR may lead to introduction of intermediate memory buffers and additional + * transformation actions which should be performed before or after usage of corresponding memory + * buffer. Additional actions will be append to provided actions queue. Corresponding to + * tr.IsReversed() value actions should be executed before or after usage of resulting ArgId. + * + * \param tr tensor requisite sequence to register + * \param action resulting action queue. If TR resolution is required execution of some + * transformation actions they will be put here + * \return associated ArgId. Should be used as argument for MemSolver. + */ + ArgId Register(const TensorRequisite& tr, ActionQue* action) { + // 1) Constant tensor. Direct reference + if (auto const_data = tr.GetConstData()) { + auto idx = const_mem_collection_.size(); + const_mem_collection_.push_back(const_data); + return MakeArgReq(ArgReqFlag::CONST, static_cast(idx)); + } + + // 2) EID mapped tensor. Direct reference + if (tr.eid_ != TensorRequisite::kUndefinedTid) { + if (ext_io_eid_.count(tr.eid_) == 0) { // Not IO tensor, means it's intermediate + if (eid2idx_tmp_.count(tr.eid_)) { + auto idx = eid2idx_tmp_.at(tr.eid_); + return MakeArgReq(ArgReqFlag::TMP_STORAGE, idx); + } else { + // register himself + auto idx = tmp_mem_collection_.size(); + tmp_mem_collection_.push_back(tr.t_desc_); + eid2idx_tmp_[tr.eid_] = idx; + return MakeArgReq(ArgReqFlag::TMP_STORAGE, static_cast(idx)); + } + } else { + auto idx = ext_mem_collection_.size(); + ext_mem_collection_.push_back({tr.eid_, tr.t_desc_}); + return MakeArgReq(ArgReqFlag::EXT_EID, static_cast(idx)); + } + } + + // 3) Tensors with transform actions + if (tr.orig_) { + // recursive register of orig TR + auto orig_arg_req = Register(*tr.orig_, action); + if (tr.reinterpret_) { + return RegisterReinterpret(orig_arg_req, tr.t_desc_); + } else { + return RegisterReorder(orig_arg_req, tr.t_desc_, tr.reverse_data_flow_, action); + } + } + + // 4) Scratchpad + ICHECK(!tr.orig_ && !tr.mem_ && tr.eid_ == TensorRequisite::kUndefinedTid); + auto idx = tmp_mem_collection_.size(); + tmp_mem_collection_.push_back(tr.t_desc_); + tmp_mem_mapping_[idx] = 0; // zero position tmp mem object is reserved for scratchpads + + auto scratchpad_size = tr.t_desc_.get_size(); + auto glob_scratchpad_size = tmp_mem_collection_[0].get_size(); + if (scratchpad_size > glob_scratchpad_size) { + tmp_mem_collection_[0] = + dnnl::memory::desc({static_cast(scratchpad_size)}, + dnnl::memory::data_type::u8, dnnl::memory::format_tag::a); + } + return MakeArgReq(TMP_STORAGE, static_cast(idx)); + } + + /*! + * \brief Construct memory solver for all registered TRs. + * \param ext_provider callback to resolve external IO buffers + * \return memory solver object to match ArgId to dnnl::memory objects + */ + MemSolver MakeSolver(const DLTensorProvider& ext_provider) const { + return MemSolverImpl(eng_, ext_provider, const_mem_collection_, ext_mem_collection_, + tmp_mem_collection_, tmp_mem_mapping_); + } + + private: + ArgId RegisterReinterpret(ArgId src_ar, const dnnl::memory::desc& desc) { + switch (src_ar.flag_) { + case TMP_STORAGE: { + auto idx = tmp_mem_collection_.size(); + tmp_mem_collection_.push_back(desc); + tmp_mem_mapping_[idx] = src_ar.idx_; + return MakeArgReq(TMP_STORAGE, idx); + } + case EXT_EID: { + auto ext_req = ext_mem_collection_[src_ar.idx_]; + auto idx = ext_mem_collection_.size(); + ext_mem_collection_.push_back({ext_req.first, desc}); + return MakeArgReq(EXT_EID, idx); + } + default: + LOG(FATAL) << "Unknown case"; + } + return {}; + } + + ArgId RegisterReorder(ArgId src_ar, const dnnl::memory::desc& desc, bool reverse_data_flow, + ActionQue* action) { + ICHECK(src_ar.flag_ == TMP_STORAGE || src_ar.flag_ == EXT_EID); + + auto src_desc = src_ar.flag_ == TMP_STORAGE ? tmp_mem_collection_[src_ar.idx_] + : ext_mem_collection_[src_ar.idx_].second; + auto idx = tmp_mem_collection_.size(); + tmp_mem_collection_.push_back(desc); + auto dst_ar = MakeArgReq(TMP_STORAGE, idx); + + // reorder action submit + if (reverse_data_flow) { + auto reorder_pd = dnnl::reorder::primitive_desc(eng_, desc, eng_, src_desc); + action->insert(action->begin(), + {dnnl::reorder(reorder_pd), {{DNNL_ARG_FROM, dst_ar}, {DNNL_ARG_TO, src_ar}}}); + } else { + auto reorder_pd = dnnl::reorder::primitive_desc(eng_, src_desc, eng_, desc); + action->push_back( + {dnnl::reorder(reorder_pd), {{DNNL_ARG_FROM, src_ar}, {DNNL_ARG_TO, dst_ar}}}); + } + return dst_ar; + } + /*! \brief Implementation of memory solver */ + class MemSolverImpl { + public: + MemSolverImpl(const dnnl::engine& eng, const DLTensorProvider& ext_data_provider, + const std::vector& const_mems, + const std::vector>& ext_mems, + const std::vector& tmp_mem_descs, + const std::map& tmp_mem_mapping) + : eng_(eng), + ext_data_provider_(ext_data_provider), + const_mems_(const_mems), + ext_mems_(ext_mems) { + // Construct temp memory objects on the fly. While we have no scratchpads + // support on VM/GraphExecutor level. + tmp_mems_.resize(tmp_mem_descs.size()); + for (size_t i = 0; i < tmp_mem_descs.size(); i++) { + auto found = tmp_mem_mapping.find(i); + + if (found != tmp_mem_mapping.end()) { + auto reuse_hdl = tmp_mems_[found->second].get_data_handle(); + tmp_mems_[i] = dnnl::memory(tmp_mem_descs[i], eng_, reuse_hdl); + } else { + tmp_mems_[i] = dnnl::memory(tmp_mem_descs[i], eng_); + } + } + } + + /*! \brief Find memory object associated with provided ArgId */ + dnnl::memory operator()(const ArgId& ar) const { + switch (ar.flag_) { + case CONST: + return const_mems_.at(ar.idx_); + case TMP_STORAGE: + return tmp_mems_.at(ar.idx_); + case EXT_EID: { + auto eid_and_desc = ext_mems_.at(ar.idx_); + auto eid = eid_and_desc.first; + auto desc = eid_and_desc.second; + + auto ext_dl_tensor = ext_data_provider_(eid); + ICHECK(ext_dl_tensor->data); + return dnnl::memory{desc, eng_, ext_dl_tensor->data}; + } + } + return {}; + } + + private: + const dnnl::engine& eng_; + const DLTensorProvider& ext_data_provider_; + const std::vector& const_mems_; + const std::vector>& ext_mems_; + std::vector tmp_mems_; + }; + + ArgId MakeArgReq(ArgReqFlag flag, uint32_t idx) { return {flag, idx}; } + + /* Collection of const memory objects. */ + std::vector const_mem_collection_; + + /* Collection of intermediate memory descriptors. Zero position is reserved for scratchpads. */ + std::vector tmp_mem_collection_; + + /* Mapping of some temp buffer on previously registered. */ + std::map tmp_mem_mapping_; + + /* Collection of external_intermediate memory objects. + * first - eid of external buffer to ask + * second - t_desc describes how to treat external buffer */ + std::vector> ext_mem_collection_; + + /* Map of eid to index of temp buffer in tmp_mem_collection_ */ + std::unordered_map eid2idx_tmp_; + + /* List of external eid */ + std::set ext_io_eid_; + + /* Engine of all tensors existing in this registry */ + dnnl::engine eng_; + + /* Execution stream use to reorder const data */ + dnnl::stream stream_; +}; + +} // namespace contrib +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTRIB_DNNL_DNNL_TENSOR_REQUISITE_H_ diff --git a/src/runtime/contrib/dnnl/dnnl_utils.cc b/src/runtime/contrib/dnnl/dnnl_utils.cc index 7e79f1c939cf..23992209f2ad 100644 --- a/src/runtime/contrib/dnnl/dnnl_utils.cc +++ b/src/runtime/contrib/dnnl/dnnl_utils.cc @@ -23,11 +23,14 @@ #include "dnnl_utils.h" +#include "tvm/runtime/logging.h" + namespace tvm { namespace runtime { namespace contrib { -using dt = dnnl::memory::data_type; -dt dtype_dl2dnnl(DLDataType dltype) { + +dnnl::memory::data_type dtype_dl2dnnl(DLDataType dltype) { + using dt = dnnl::memory::data_type; dt dnnl_type = dt::undef; if (dltype.code == DataType::TypeCode::kFloat) { if (dltype.bits == 16) { @@ -51,6 +54,23 @@ dt dtype_dl2dnnl(DLDataType dltype) { } return dnnl_type; } + +dnnl::memory::dims shape_dl2dnnl(const std::vector& shape) { + if (shape.empty()) return {1}; // DNNL scalar representation is 1D tensor + return shape; +} + +dnnl::memory::desc MakePlainDesc(const std::vector& shape, DLDataType dltype) { + auto dnnl_shape = shape_dl2dnnl(shape); + auto dnnl_dtype = dtype_dl2dnnl(dltype); + + auto dnnl_plain_strides = dnnl::memory::dims(dnnl_shape.size(), 1); + for (int i = dnnl_shape.size() - 2; i >= 0; i--) + dnnl_plain_strides[i] = dnnl_plain_strides[i + 1] * dnnl_shape[i + 1]; + + return {dnnl_shape, dnnl_dtype, dnnl_plain_strides}; +} + } // namespace contrib } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/dnnl/dnnl_utils.h b/src/runtime/contrib/dnnl/dnnl_utils.h index 4fb236f96f8b..a598b6704450 100644 --- a/src/runtime/contrib/dnnl/dnnl_utils.h +++ b/src/runtime/contrib/dnnl/dnnl_utils.h @@ -18,16 +18,23 @@ */ /*! - * \file src/runtime/contrib/dnnl/dnnl_utils.h - * \brief utils for DNNL. + * \file src/runtime/contrib/dnnl/dnnl_utils.cc + * \brief Some DNNL specific utility functions */ #ifndef TVM_RUNTIME_CONTRIB_DNNL_DNNL_UTILS_H_ #define TVM_RUNTIME_CONTRIB_DNNL_DNNL_UTILS_H_ -#include +#include +#include +#include +#include -#include "dnnl.hpp" +// TODO(@apeskov): Have to mute warning from dnnl headers. +// -Wzero-as-null-pointer-constant and -Wdocumentation-unknown-command +#include + +#include "tvm/runtime/data_type.h" namespace tvm { namespace runtime { @@ -40,7 +47,90 @@ namespace contrib { */ dnnl::memory::data_type dtype_dl2dnnl(DLDataType dltype); +/*! + * \brief Converter TVM shape to DNNL dims + * \param shape tvm shape + * \return dims in terms of dnnl + */ +dnnl::memory::dims shape_dl2dnnl(const std::vector& shape); + +/*! + * \brief Construct plain tensor descriptor + * \param shape provided shape + * \param dltype provided data type + * \return resulting plain tensor desc + */ +dnnl::memory::desc MakePlainDesc(const std::vector& shape, DLDataType dltype); + +namespace utils { + +/*! \brief Pretty printer util for shape */ +inline std::ostream& operator<<(std::ostream& o, const dnnl::memory::dims& dims) { + o << "["; + auto d = dims.begin(); + if (d != dims.end()) o << *d++; + while (d != dims.end()) o << "," << *d++; + o << "]"; + return o; +} + +/*! \brief Pretty printer util for data type */ +inline std::ostream& operator<<(std::ostream& o, const dnnl::memory::data_type& type) { + std::string name = "undef"; + switch (type) { + case dnnl::memory::data_type::undef: + name = "undef"; + break; + case dnnl::memory::data_type::f32: + name = "fp32"; + break; + case dnnl::memory::data_type::f16: + name = "fp16"; + break; + case dnnl::memory::data_type::bf16: + name = "bf16"; + break; + case dnnl::memory::data_type::s32: + name = "i32"; + break; + case dnnl::memory::data_type::s8: + name = "i8"; + break; + case dnnl::memory::data_type::u8: + name = "u8"; + break; + } + o << name; + return o; +} + +/*! \brief Converter data type template arg to runtime object */ +template +inline dnnl::memory::data_type DnnlDType(); + +template <> +inline dnnl::memory::data_type DnnlDType() { + return dnnl::memory::data_type::s32; +} + +template <> +inline dnnl::memory::data_type DnnlDType() { + return dnnl::memory::data_type::f32; +} + +template <> +inline dnnl::memory::data_type DnnlDType() { + return dnnl::memory::data_type::u8; +} + +template <> +inline dnnl::memory::data_type DnnlDType() { + return dnnl::memory::data_type::s8; +} + +} // namespace utils } // namespace contrib } // namespace runtime } // namespace tvm + #endif // TVM_RUNTIME_CONTRIB_DNNL_DNNL_UTILS_H_