Use shape inference to check concat, softmax and global pooling
daquexian committed Jul 16, 2019
1 parent ecc612d commit a724168
Showing 2 changed files with 57 additions and 5 deletions.
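The substance of the change: GetSupportedNodes now runs ONNX shape inference up front, so the inferred shape of every intermediate tensor becomes available in graph().value_info(), and a new GetShape helper reads it back when checking Concat, Softmax, and global pooling nodes. A minimal standalone sketch of that mechanism, assuming a serialized model on disk (the "model.onnx" path and the printing loop are illustrative, not part of this commit):

// A minimal sketch, not part of this commit: run ONNX shape inference
// and dump what it adds to graph().value_info().
#include <fstream>
#include <iostream>
#include <string>

#include <onnx/onnx_pb.h>
#include <onnx/shape_inference/implementation.h>

int main() {
    ONNX_NAMESPACE::ModelProto model;
    std::ifstream in("model.onnx", std::ios::binary);  // hypothetical path
    model.ParseFromIstream(&in);

    // InferShapes mutates the model in place, recording the inferred
    // type/shape of every intermediate tensor it can resolve.
    ONNX_NAMESPACE::shape_inference::InferShapes(model);

    for (const auto &vi : model.graph().value_info()) {
        if (!vi.type().has_tensor_type()) continue;
        std::cout << vi.name() << ":";
        for (const auto &dim : vi.type().tensor_type().shape().dim()) {
            // Dims without a concrete dim_value stay symbolic; the new
            // GetShape helper in this commit bails out on exactly those.
            std::cout << ' '
                      << (dim.has_dim_value() ? std::to_string(dim.dim_value())
                                              : "?");
        }
        std::cout << '\n';
    }
    return 0;
}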
6 changes: 4 additions & 2 deletions include/tools/onnx2daq/OnnxConverter.h
@@ -80,11 +80,13 @@ class OnnxConverter {
 
     void HandleInitializer();
     std::vector<flatbuffers::Offset<DNN::Input>> GetInputOfOnnxModel();
-    std::vector<flatbuffers::Offset<flatbuffers::String>> GetOutputOfOnnxModel();
+    std::vector<flatbuffers::Offset<flatbuffers::String>>
+    GetOutputOfOnnxModel();
     void ReadTableFile(const std::string &table_file);
     std::vector<flatbuffers::Offset<DNN::QuantInfo>> ConvertQuantInfosToFbs();
 
     std::pair<bool, std::string> IsNodeSupported(
+        const ONNX_NAMESPACE::ModelProto &model_proto,
         const ONNX_NAMESPACE::NodeProto &node_proto) const;
 
     void AddConv(const std::string &input_name, const std::vector<int> &strides,
@@ -184,7 +186,7 @@ class OnnxConverter {
 
  public:
     std::vector<std::vector<int>> GetSupportedNodes(
-        const ONNX_NAMESPACE::ModelProto &model);
+        ONNX_NAMESPACE::ModelProto model_proto);
     void Convert(const std::string &model_str, const std::string &filepath,
                  const std::string &table_file = "");
     void Convert(const ONNX_NAMESPACE::ModelProto &model,
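Note the signature change above: GetSupportedNodes now takes the ModelProto by value instead of by const reference, because ONNX's InferShapes mutates the proto in place and the function needs its own writable copy. A hypothetical call site (the converter and model variables are illustrative names, not from this commit):

// Hypothetical usage; 'converter', 'model', and 'model_bytes' are
// illustrative names, not part of this commit.
OnnxConverter converter;
ONNX_NAMESPACE::ModelProto model;
model.ParseFromString(model_bytes);  // model_bytes: assumed std::string
// Pass-by-value hands GetSupportedNodes a private copy that
// InferShapes is free to mutate.
const auto supported = converter.GetSupportedNodes(model);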
56 changes: 53 additions & 3 deletions tools/onnx2daq/OnnxConverter.cpp
@@ -10,6 +10,7 @@
 #include <common/helper.h>
 #include <glog/logging.h>
 #include <onnx/optimizer/optimize.h>
+#include <onnx/shape_inference/implementation.h>
 #include "NodeAttrHelper.h"
 
 using std::string;
@@ -1032,7 +1033,40 @@ void OnnxConverter::Convert(const std::string &model_str,
     Save(filepath);
 }
 
+dnn::optional<Shaper::Shape> GetShape(
+    const ONNX_NAMESPACE::ModelProto &model_proto, const std::string &name) {
+    for (const auto &value_info : model_proto.graph().value_info()) {
+        if (value_info.name() == name) {
+            if (!value_info.has_type()) {
+                return dnn::nullopt;
+            } else if (!value_info.type().has_tensor_type()) {
+                return dnn::nullopt;
+            } else if (!value_info.type().tensor_type().has_shape()) {
+                return dnn::nullopt;
+            } else if (value_info.type().tensor_type().shape().dim_size() ==
+                       0) {
+                return dnn::nullopt;
+            }
+
+            Shape shape;
+            for (const auto &dim :
+                 value_info.type().tensor_type().shape().dim()) {
+                if (dim.has_dim_value()) {
+                    shape.push_back(dim.dim_value());
+                } else {
+                    return dnn::nullopt;
+                }
+            }
+
+            return shape;
+        }
+    }
+
+    return dnn::nullopt;
+}
+
 std::pair<bool, std::string> OnnxConverter::IsNodeSupported(
+    const ONNX_NAMESPACE::ModelProto &model_proto,
     const ONNX_NAMESPACE::NodeProto &node) const {
     NodeAttrHelper helper(node);
     const auto &op = node.op_type();
@@ -1089,12 +1123,18 @@ std::pair<bool, std::string> OnnxConverter::IsNodeSupported(
         if (helper.get("ceil_mode", 0) == 1) {
             return {false, "ceil_mode == 1 is not supported for pooling"};
         }
-        if (helper.get("dilations", std::vector<int>{1, 1}) != std::vector<int>{1, 1}) {
+        if (helper.get("dilations", std::vector<int>{1, 1}) !=
+            std::vector<int>{1, 1}) {
             return {false, "Dilations of pooling is not supported"};
         }
         if (node.output_size() != 1) {
             return {false, "Argmax in maxpooling is not supported"};
         }
+    } else if (op == "GlobalAveragePool" || op == "GlobalMaxPool") {
+        const auto &input_shape = GetShape(model_proto, node.input(0));
+        if (!input_shape.has_value() || input_shape.value().size() != 4) {
+            return {false, "Only rank-4 tensor is supported in " + op};
+        }
     } else if (op == "PRelu") {
         const auto slope_name = m(node.input(1));
         if (onnx_tensors_.has(slope_name)) {
@@ -1159,6 +1199,15 @@ std::pair<bool, std::string> OnnxConverter::IsNodeSupported(
         if (axis != 1) {
             return {false, "Only axis == 1 is supported in Softmax"};
         }
+        const auto &input_shape = GetShape(model_proto, node.input(0));
+        if (!input_shape.has_value() || input_shape.value().size() != 4) {
+            return {false, "Only rank-4 tensor is supported in Softmax"};
+        }
+    } else if (op == "Concat") {
+        const auto &input_shape = GetShape(model_proto, node.input(0));
+        if (!input_shape.has_value() || input_shape.value().size() != 4) {
+            return {false, "Only rank-4 tensor is supported in Concat"};
+        }
     }
     return {true, ""};
 }
@@ -1181,8 +1230,9 @@ bool IsValidSupportedNodesVec(const std::vector<int> &supported_node_vec,
 }
 
 std::vector<std::vector<int>> OnnxConverter::GetSupportedNodes(
-    const ONNX_NAMESPACE::ModelProto &model_proto) {
+    ONNX_NAMESPACE::ModelProto model_proto) {
     GOOGLE_PROTOBUF_VERIFY_VERSION;
+    ONNX_NAMESPACE::shape_inference::InferShapes(model_proto);
     model_proto_ = model_proto;
     HandleInitializer();
 
@@ -1192,7 +1242,7 @@ std::vector<std::vector<int>> OnnxConverter::GetSupportedNodes(
         bool supported;
         std::string error_msg;
         std::tie(supported, error_msg) =
-            IsNodeSupported(model_proto.graph().node(i));
+            IsNodeSupported(model_proto, model_proto.graph().node(i));
         if (supported) {
             supported_node_vec.push_back(i);
         } else {
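GetShape above is deliberately conservative: it returns dnn::nullopt unless the tensor appears in value_info with a tensor type, a non-empty shape, and a concrete dim_value for every dimension, so a single symbolic dimension is enough for a Concat, Softmax, or global pooling node to be rejected. An illustrative fragment, with a made-up tensor name, showing the kind of value_info entry that fails the check:

// Illustrative only: a tensor whose batch dim is the symbolic "N".
ONNX_NAMESPACE::ValueInfoProto vi;
vi.set_name("conv1_out");  // hypothetical tensor name
auto *shape = vi.mutable_type()->mutable_tensor_type()->mutable_shape();
shape->add_dim()->set_dim_param("N");  // symbolic dim: no dim_value
shape->add_dim()->set_dim_value(64);
shape->add_dim()->set_dim_value(56);
shape->add_dim()->set_dim_value(56);
// GetShape hits the !has_dim_value() branch on the first dim and
// returns dnn::nullopt, so IsNodeSupported reports the node as
// unsupported rather than guessing a shape.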
