From a22c6ed5ea6bf3030e9ce03b9936e1c289fb8df1 Mon Sep 17 00:00:00 2001 From: daquexian Date: Sun, 12 May 2019 21:42:21 +0800 Subject: [PATCH] Support batch input in model builder --- dnnlibrary/include/ModelBuilder.h | 5 +++-- dnnlibrary/src/DaqReader.cpp | 2 +- dnnlibrary/src/ModelBuilder.cpp | 5 +++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/dnnlibrary/include/ModelBuilder.h b/dnnlibrary/include/ModelBuilder.h index 715146d..21815f3 100644 --- a/dnnlibrary/include/ModelBuilder.h +++ b/dnnlibrary/include/ModelBuilder.h @@ -97,8 +97,9 @@ class ModelBuilder { Index GetBlobIndex(const std::string &blobName); Shape GetBlobDim(const std::string &blobName); Shape GetBlobDim(Index index); - Index AddInput(std::string name, const uint32_t height, - const uint32_t width, const uint32_t depth); + Index AddInput(std::string name, const uint32_t batch, + const uint32_t height, const uint32_t width, + const uint32_t depth); Index AddInput(std::string name, const android::nn::wrapper::OperandType &operand_type); // ModelBuilder auto generated methods start diff --git a/dnnlibrary/src/DaqReader.cpp b/dnnlibrary/src/DaqReader.cpp index a3c20bd..94f944d 100644 --- a/dnnlibrary/src/DaqReader.cpp +++ b/dnnlibrary/src/DaqReader.cpp @@ -170,7 +170,7 @@ void AddInputs(const DNN::Model &model, ModelBuilder &builder) { quant_info.zero_point_.value_or(0)); builder.AddInput(input_name, operand_type); } else { - builder.AddInput(input_name, shape[1], shape[2], shape[3]); + builder.AddInput(input_name, shape[0], shape[1], shape[2], shape[3]); } } } diff --git a/dnnlibrary/src/ModelBuilder.cpp b/dnnlibrary/src/ModelBuilder.cpp index 0a0839b..b319179 100644 --- a/dnnlibrary/src/ModelBuilder.cpp +++ b/dnnlibrary/src/ModelBuilder.cpp @@ -50,10 +50,11 @@ void ModelBuilder::RegisterOperand(const std::string &name, operand_types_.insert({name, operand_type}); } -ModelBuilder::Index ModelBuilder::AddInput(string name, const uint32_t height, +ModelBuilder::Index ModelBuilder::AddInput(string name, const uint32_t batch, + const uint32_t height, const uint32_t width, const uint32_t depth) { - const vector dimen{1, width, height, depth}; + const vector dimen{batch, width, height, depth}; return AddInput(name, {Type::TENSOR_FLOAT32, dimen}); }