Skip to content

Commit

Permalink
Support batch input in model builder
Browse files Browse the repository at this point in the history
  • Loading branch information
daquexian committed May 13, 2019
1 parent da2565d commit a22c6ed
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
5 changes: 3 additions & 2 deletions dnnlibrary/include/ModelBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,9 @@ class ModelBuilder {
Index GetBlobIndex(const std::string &blobName);
Shape GetBlobDim(const std::string &blobName);
Shape GetBlobDim(Index index);
Index AddInput(std::string name, const uint32_t height,
const uint32_t width, const uint32_t depth);
Index AddInput(std::string name, const uint32_t batch,
const uint32_t height, const uint32_t width,
const uint32_t depth);
Index AddInput(std::string name,
const android::nn::wrapper::OperandType &operand_type);
// ModelBuilder auto generated methods start
Expand Down
2 changes: 1 addition & 1 deletion dnnlibrary/src/DaqReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ void AddInputs(const DNN::Model &model, ModelBuilder &builder) {
quant_info.zero_point_.value_or(0));
builder.AddInput(input_name, operand_type);
} else {
builder.AddInput(input_name, shape[1], shape[2], shape[3]);
builder.AddInput(input_name, shape[0], shape[1], shape[2], shape[3]);
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions dnnlibrary/src/ModelBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,11 @@ void ModelBuilder::RegisterOperand(const std::string &name,
operand_types_.insert({name, operand_type});
}

ModelBuilder::Index ModelBuilder::AddInput(string name, const uint32_t height,
ModelBuilder::Index ModelBuilder::AddInput(string name, const uint32_t batch,
const uint32_t height,
const uint32_t width,
const uint32_t depth) {
const vector<uint32_t> dimen{1, width, height, depth};
const vector<uint32_t> dimen{batch, width, height, depth};
return AddInput(name, {Type::TENSOR_FLOAT32, dimen});
}

Expand Down

0 comments on commit a22c6ed

Please sign in to comment.