
Commit

fix lint
sunwayforever committed Sep 26, 2021
1 parent 031241e commit da11a6c
Showing 3 changed files with 11 additions and 11 deletions.
9 changes: 3 additions & 6 deletions src/relay/backend/contrib/dnnl/codegen.cc
@@ -247,12 +247,9 @@ class CodegenDNNL : public MemoizedExprTranslator<std::vector<Output>>, public C

using ArgFunType = std::function<std::vector<std::string>(const CallNode*)>;
static const std::map<std::string, std::pair<std::string, ArgFunType>> op_map = {
{"nn.conv2d", {"dnnl_conv2d", Conv2d}},
{"nn.dense", {"dnnl_dense", Dense}},
{"nn.relu", {"dnnl_relu", Relu}},
{"nn.batch_norm", {"dnnl_bn", BatchNorm}},
{"add", {"dnnl_binary_op", Add}},
{"multiply", {"dnnl_binary_op", Multiply}},
{"nn.conv2d", {"dnnl_conv2d", Conv2d}}, {"nn.dense", {"dnnl_dense", Dense}},
{"nn.relu", {"dnnl_relu", Relu}}, {"nn.batch_norm", {"dnnl_bn", BatchNorm}},
{"add", {"dnnl_binary_op", Add}}, {"multiply", {"dnnl_binary_op", Multiply}},
};

const auto op_name = GetRef<Op>(op_node)->name;
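Note: the op_map hunk above is a pure clang-format reflow; the table still maps each supported Relay op name to the DNNL runtime symbol to emit and an argument-extraction callback. A minimal sketch of how such a table can be consulted (the helper name and surrounding types are illustrative, not TVM's actual codegen code):

#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

struct CallNode;  // stand-in for tvm::relay::CallNode, declaration only

using ArgFunType = std::function<std::vector<std::string>(const CallNode*)>;
using OpMap = std::map<std::string, std::pair<std::string, ArgFunType>>;

// Look up the external symbol and argument extractor registered for a Relay op.
// Returns false when the op is not handled by this codegen.
inline bool LookupDnnlOp(const OpMap& op_map, const std::string& op_name,
                         std::string* ext_symbol, ArgFunType* arg_fun) {
  auto it = op_map.find(op_name);
  if (it == op_map.end()) return false;
  *ext_symbol = it->second.first;
  *arg_fun = it->second.second;
  return true;
}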
6 changes: 3 additions & 3 deletions src/runtime/contrib/dnnl/dnnl.cc
@@ -201,7 +201,7 @@ extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, int
read_from_dnnl_memory(out, dst_memory);
}

extern "C" void dnnl_relu(float* data, float* out, std::vector<long int> shape) {
extern "C" void dnnl_relu(float* data, float* out, std::vector<int64_t> shape) {
using dt = memory::data_type;

engine eng(engine::kind::cpu, 0);
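Note: the only change in this hunk is `long int` -> `int64_t`. `long` is 32 bits on LLP64 platforms such as 64-bit Windows, whereas `int64_t` is a fixed 64-bit width everywhere and is what cpplint's runtime/int check asks for. A minimal sketch, assuming `shape` holds per-dimension extents, of how a kernel can derive the flattened element count from such a vector:

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Multiply all dimension extents to get the total number of elements;
// the 64-bit accumulator avoids overflow for large tensors.
inline int64_t ElementCount(const std::vector<int64_t>& shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                         std::multiplies<int64_t>());
}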
@@ -265,7 +265,7 @@ extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, flo
}

extern "C" void dnnl_binary_op(float* data, float* weight, float* out, int algo_type,
- std::vector<long int> shape) {
+ std::vector<int64_t> shape) {
using dt = memory::data_type;

engine eng(engine::kind::cpu, 0);
@@ -287,7 +287,7 @@ extern "C" void dnnl_binary_op(float* data, float* weight, float* out, int algo_
default:
assert(true);
break;
- };
+ }

auto add_desc = binary::desc(algo, data_md, data_md, data_md);
auto add_prim_desc = binary::primitive_desc(add_desc, eng);
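Note: this hunk only drops the stray `;` after the closing brace of the switch; the extra semicolon is an empty statement that cpplint reports as "You don't need a ; after a }". A sketch of the selection pattern involved, with an illustrative encoding of `algo_type` (the actual mapping used by the runtime is not shown in this diff):

#include "dnnl.hpp"

// Map an integer selector onto a DNNL binary algorithm; the 0/1 encoding here
// is illustrative only. Unknown values fall back to addition.
inline dnnl::algorithm SelectBinaryAlgo(int algo_type) {
  switch (algo_type) {
    case 0:
      return dnnl::algorithm::binary_add;
    case 1:
      return dnnl::algorithm::binary_mul;
    default:
      return dnnl::algorithm::binary_add;
  }  // no ';' after this brace: a switch statement does not take one
}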
7 changes: 5 additions & 2 deletions src/runtime/contrib/dnnl/dnnl_kernel.h
@@ -27,6 +27,8 @@

#include <tvm/runtime/c_runtime_api.h>

+ #include <vector>
+
#include "dnnl.hpp"

namespace tvm {
@@ -54,13 +56,14 @@ extern "C" TVM_DLL void dnnl_fused_conv2d_bias_relu(float* data, float* weights,
extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_,
int p_O_);

extern "C" TVM_DLL void dnnl_relu(float* data, float* out, std::vector<long int> shape);
extern "C" TVM_DLL void dnnl_relu(float* data, float* out, std::vector<int64_t> shape);

extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
float* variance, float* out, float* new_mean, float* new_variance,
int p_n_, int p_c_, int p_h_, int p_w_, int p_e_);

extern "C" TVM_DLL void dnnl_binary_op(float* data, float* weight, float* out, int binary_algo, std::vector<long int> shape);
extern "C" TVM_DLL void dnnl_binary_op(float* data, float* weight, float* out, int binary_algo,
std::vector<int64_t> shape);

} // namespace contrib
} // namespace runtime
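Note: the new `#include <vector>` makes this header self-sufficient now that its declarations take `std::vector` parameters, which is what cpplint's build/include_what_you_use check expects. A minimal caller of the re-declared dnnl_relu, assuming the header is on the include path and the TVM DNNL runtime plus oneDNN are linked; the 2x3 shape is illustrative:

#include <cstdint>
#include <vector>

#include "dnnl_kernel.h"  // assumed include path for the declarations above

int main() {
  std::vector<float> data(6, -1.0f), out(6, 0.0f);
  std::vector<int64_t> shape = {2, 3};  // illustrative 2x3 tensor
  // extern "C" symbol declared inside tvm::runtime::contrib
  tvm::runtime::contrib::dnnl_relu(data.data(), out.data(), shape);
  return 0;
}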
